diff --git a/scripts/conversion/convert_diffusers_autoencoder_kl.py b/scripts/conversion/convert_diffusers_autoencoder_kl.py index d53f1be..afb1c26 100644 --- a/scripts/conversion/convert_diffusers_autoencoder_kl.py +++ b/scripts/conversion/convert_diffusers_autoencoder_kl.py @@ -1,10 +1,12 @@ import argparse from pathlib import Path + import torch -from torch import nn from diffusers import AutoencoderKL # type: ignore -from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder +from torch import nn + from refiners.fluxion.model_converter import ModelConverter +from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder class Args(argparse.Namespace): diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index 11493cb..2a6f742 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -1,15 +1,17 @@ # pyright: reportPrivateUsage=false import argparse from pathlib import Path + import torch -from torch import nn from diffusers import ControlNetModel # type: ignore -from refiners.fluxion.utils import save_to_safetensors +from torch import nn + from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import save_to_safetensors from refiners.foundationals.latent_diffusion import ( - SD1UNet, - SD1ControlnetAdapter, DPMSolver, + SD1ControlnetAdapter, + SD1UNet, ) diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index 0d48660..8fd33be 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -1,11 +1,11 @@ +import argparse from pathlib import Path from typing import Any -import argparse import torch -from refiners.foundationals.latent_diffusion import SD1UNet, SD1IPAdapter, SDXLUNet, SDXLIPAdapter from refiners.fluxion.utils import save_to_safetensors +from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet # Running: # diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py index 07d8fdc..b106794 100644 --- a/scripts/conversion/convert_diffusers_lora.py +++ b/scripts/conversion/convert_diffusers_lora.py @@ -3,16 +3,15 @@ from pathlib import Path from typing import cast import torch -from torch import Tensor -from torch.nn.init import zeros_ -from torch.nn import Parameter as TorchParameter - from diffusers import DiffusionPipeline # type: ignore +from torch import Tensor +from torch.nn import Parameter as TorchParameter +from torch.nn.init import zeros_ import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.lora import Lora, LoraAdapter from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import save_to_safetensors -from refiners.fluxion.adapters.lora import Lora, LoraAdapter from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets diff --git a/scripts/conversion/convert_diffusers_t2i_adapter.py b/scripts/conversion/convert_diffusers_t2i_adapter.py index 814fd22..9ac4a8c 100644 --- a/scripts/conversion/convert_diffusers_t2i_adapter.py +++ b/scripts/conversion/convert_diffusers_t2i_adapter.py @@ -1,10 +1,12 @@ import argparse from pathlib import Path + import torch -from torch import nn from diffusers import T2IAdapter # type: ignore -from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, ConditionEncoderXL +from torch import nn + from refiners.fluxion.model_converter import ModelConverter +from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, ConditionEncoderXL if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert a pretrained diffusers T2I-Adapter model to refiners") diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py index ca9ad01..9c6f257 100644 --- a/scripts/conversion/convert_diffusers_unet.py +++ b/scripts/conversion/convert_diffusers_unet.py @@ -1,9 +1,11 @@ import argparse from pathlib import Path + import torch -from torch import nn -from refiners.fluxion.model_converter import ModelConverter from diffusers import UNet2DConditionModel # type: ignore +from torch import nn + +from refiners.fluxion.model_converter import ModelConverter from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet diff --git a/scripts/conversion/convert_informative_drawings.py b/scripts/conversion/convert_informative_drawings.py index 75f4000..cb20f36 100644 --- a/scripts/conversion/convert_informative_drawings.py +++ b/scripts/conversion/convert_informative_drawings.py @@ -1,7 +1,9 @@ import argparse from typing import TYPE_CHECKING, cast + import torch from torch import nn + from refiners.fluxion.model_converter import ModelConverter from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings diff --git a/scripts/conversion/convert_refiners_lora_to_sdwebui.py b/scripts/conversion/convert_refiners_lora_to_sdwebui.py index 839026c..e8105ea 100644 --- a/scripts/conversion/convert_refiners_lora_to_sdwebui.py +++ b/scripts/conversion/convert_refiners_lora_to_sdwebui.py @@ -1,20 +1,22 @@ import argparse from functools import partial + +from convert_diffusers_unet import Args as UnetConversionArgs, setup_converter as convert_unet +from convert_transformers_clip_text_model import ( + Args as TextEncoderConversionArgs, + setup_converter as convert_text_encoder, +) from torch import Tensor + +import refiners.fluxion.layers as fl from refiners.fluxion.utils import ( load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors, ) -from convert_diffusers_unet import setup_converter as convert_unet, Args as UnetConversionArgs -from convert_transformers_clip_text_model import ( - setup_converter as convert_text_encoder, - Args as TextEncoderConversionArgs, -) from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget -import refiners.fluxion.layers as fl def get_unet_mapping(source_path: str) -> dict[str, str]: diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index b43d575..9057cbb 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -1,20 +1,19 @@ import argparse import types from typing import Any, Callable, cast + import torch import torch.nn as nn +from segment_anything import build_sam_vit_h # type: ignore +from segment_anything.modeling.common import LayerNorm2d # type: ignore from torch import Tensor import refiners.fluxion.layers as fl from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import manual_seed, save_to_safetensors from refiners.foundationals.segment_anything.image_encoder import SAMViTH -from refiners.foundationals.segment_anything.prompt_encoder import PointEncoder, MaskEncoder - -from segment_anything import build_sam_vit_h # type: ignore -from segment_anything.modeling.common import LayerNorm2d # type: ignore - from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder +from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder class FacebookSAM(nn.Module): @@ -134,9 +133,10 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]: point_embedding = torch.randn(1, 3, 256) mask_embedding = torch.randn(1, 256, 64, 64) - import refiners.fluxion.layers as fl from segment_anything.modeling.common import LayerNorm2d # type: ignore + import refiners.fluxion.layers as fl + assert issubclass(LayerNorm2d, nn.Module) custom_layers = {LayerNorm2d: fl.LayerNorm2d} diff --git a/scripts/conversion/convert_transformers_clip_image_model.py b/scripts/conversion/convert_transformers_clip_image_model.py index 9ae73d8..d3edd9f 100644 --- a/scripts/conversion/convert_transformers_clip_image_model.py +++ b/scripts/conversion/convert_transformers_clip_image_model.py @@ -1,12 +1,14 @@ import argparse from pathlib import Path -from torch import nn -from refiners.fluxion.model_converter import ModelConverter -from transformers import CLIPVisionModelWithProjection # type: ignore -from refiners.foundationals.clip.image_encoder import CLIPImageEncoder -from refiners.fluxion.utils import save_to_safetensors + import torch +from torch import nn +from transformers import CLIPVisionModelWithProjection # type: ignore + import refiners.fluxion.layers as fl +from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import save_to_safetensors +from refiners.foundationals.clip.image_encoder import CLIPImageEncoder class Args(argparse.Namespace): diff --git a/scripts/conversion/convert_transformers_clip_text_model.py b/scripts/conversion/convert_transformers_clip_text_model.py index bbe8679..2b74092 100644 --- a/scripts/conversion/convert_transformers_clip_text_model.py +++ b/scripts/conversion/convert_transformers_clip_text_model.py @@ -1,14 +1,16 @@ import argparse from pathlib import Path from typing import cast + from torch import nn -from refiners.fluxion.model_converter import ModelConverter from transformers import CLIPTextModelWithProjection # type: ignore -from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderL, CLIPTextEncoderG + +import refiners.fluxion.layers as fl +from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import save_to_safetensors +from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, CLIPTextEncoderG, CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder -from refiners.fluxion.utils import save_to_safetensors -import refiners.fluxion.layers as fl class Args(argparse.Namespace): diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py index 11d02e4..73fa359 100644 --- a/scripts/training/finetune-ldm-lora.py +++ b/scripts/training/finetune-ldm-lora.py @@ -1,20 +1,21 @@ import random from typing import Any -from pydantic import BaseModel + from loguru import logger -from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets -import refiners.fluxion.layers as fl +from pydantic import BaseModel from torch import Tensor from torch.utils.data import Dataset +import refiners.fluxion.layers as fl +from refiners.fluxion.utils import save_to_safetensors +from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets from refiners.training_utils.callback import Callback from refiners.training_utils.latent_diffusion import ( FinetuneLatentDiffusionConfig, + LatentDiffusionConfig, + LatentDiffusionTrainer, TextEmbeddingLatentsBatch, TextEmbeddingLatentsDataset, - LatentDiffusionTrainer, - LatentDiffusionConfig, ) diff --git a/scripts/training/finetune-ldm-textual-inversion.py b/scripts/training/finetune-ldm-textual-inversion.py index 4d704b3..66bf44e 100644 --- a/scripts/training/finetune-ldm-textual-inversion.py +++ b/scripts/training/finetune-ldm-textual-inversion.py @@ -1,24 +1,24 @@ -from typing import Any -from pydantic import BaseModel -from loguru import logger -from torch.utils.data import Dataset -from torch import randn, Tensor import random +from typing import Any +from loguru import logger +from pydantic import BaseModel +from torch import Tensor, randn +from torch.utils.data import Dataset + +from refiners.fluxion.utils import save_to_safetensors from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExtender from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.fluxion.utils import save_to_safetensors from refiners.training_utils.callback import Callback from refiners.training_utils.latent_diffusion import ( FinetuneLatentDiffusionConfig, - TextEmbeddingLatentsBatch, - LatentDiffusionTrainer, LatentDiffusionConfig, + LatentDiffusionTrainer, + TextEmbeddingLatentsBatch, TextEmbeddingLatentsDataset, ) - IMAGENET_TEMPLATES_SMALL = [ "a photo of a {}", "a rendering of a {}", diff --git a/src/refiners/fluxion/__init__.py b/src/refiners/fluxion/__init__.py index dac0ea9..048a6c1 100644 --- a/src/refiners/fluxion/__init__.py +++ b/src/refiners/fluxion/__init__.py @@ -1,3 +1,3 @@ -from refiners.fluxion.utils import save_to_safetensors, load_from_safetensors, norm, manual_seed, pad +from refiners.fluxion.utils import load_from_safetensors, manual_seed, norm, pad, save_to_safetensors __all__ = ["norm", "manual_seed", "save_to_safetensors", "load_from_safetensors", "pad"] diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py index 9f27153..5f55d40 100644 --- a/src/refiners/fluxion/adapters/adapter.py +++ b/src/refiners/fluxion/adapters/adapter.py @@ -1,7 +1,7 @@ import contextlib -import refiners.fluxion.layers as fl -from typing import Any, Generic, TypeVar, Iterator +from typing import Any, Generic, Iterator, TypeVar +import refiners.fluxion.layers as fl T = TypeVar("T", bound=fl.Module) TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673) diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index fc0d374..b0f2e89 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -1,11 +1,11 @@ -from typing import Iterable, Generic, TypeVar, Any - -import refiners.fluxion.layers as fl -from refiners.fluxion.adapters.adapter import Adapter +from typing import Any, Generic, Iterable, TypeVar from torch import Tensor, device as Device, dtype as DType from torch.nn import Parameter as TorchParameter -from torch.nn.init import zeros_, normal_ +from torch.nn.init import normal_, zeros_ + +import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.adapter import Adapter T = TypeVar("T", bound=fl.Chain) TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673) diff --git a/src/refiners/fluxion/context.py b/src/refiners/fluxion/context.py index 08914a7..e8871e2 100644 --- a/src/refiners/fluxion/context.py +++ b/src/refiners/fluxion/context.py @@ -1,4 +1,5 @@ from typing import Any + from torch import Tensor Context = dict[str, Any] diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py index de38cc0..1446244 100644 --- a/src/refiners/fluxion/layers/__init__.py +++ b/src/refiners/fluxion/layers/__init__.py @@ -1,50 +1,50 @@ -from refiners.fluxion.layers.activations import GLU, SiLU, ReLU, ApproximateGeLU, GeLU, Sigmoid -from refiners.fluxion.layers.norm import LayerNorm, GroupNorm, LayerNorm2d, InstanceNorm2d +from refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU, Sigmoid, SiLU from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d from refiners.fluxion.layers.basics import ( - Identity, - View, + Buffer, + Chunk, + Cos, Flatten, - Unflatten, - Transpose, GetArg, + Identity, + Multiply, + Parameter, Permute, Reshape, - Squeeze, - Unsqueeze, - Slicing, Sin, - Cos, - Chunk, - Multiply, + Slicing, + Squeeze, + Transpose, Unbind, - Parameter, - Buffer, + Unflatten, + Unsqueeze, + View, ) from refiners.fluxion.layers.chain import ( + Breakpoint, + Chain, + Concatenate, + Distribute, Lambda, - Sum, + Matmul, + Parallel, + Passthrough, Residual, Return, - Chain, - UseContext, SetContext, - Parallel, - Distribute, - Passthrough, - Breakpoint, - Concatenate, - Matmul, + Sum, + UseContext, ) from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d +from refiners.fluxion.layers.converter import Converter +from refiners.fluxion.layers.embedding import Embedding from refiners.fluxion.layers.linear import Linear, MultiLinear -from refiners.fluxion.layers.module import Module, WeightedModule, ContextModule +from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d +from refiners.fluxion.layers.module import ContextModule, Module, WeightedModule +from refiners.fluxion.layers.norm import GroupNorm, InstanceNorm2d, LayerNorm, LayerNorm2d from refiners.fluxion.layers.padding import ReflectionPad2d from refiners.fluxion.layers.pixelshuffle import PixelUnshuffle -from refiners.fluxion.layers.sampling import Downsample, Upsample, Interpolate -from refiners.fluxion.layers.embedding import Embedding -from refiners.fluxion.layers.converter import Converter -from refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d +from refiners.fluxion.layers.sampling import Downsample, Interpolate, Upsample __all__ = [ "Embedding", diff --git a/src/refiners/fluxion/layers/activations.py b/src/refiners/fluxion/layers/activations.py index eca9afd..1c0df61 100644 --- a/src/refiners/fluxion/layers/activations.py +++ b/src/refiners/fluxion/layers/activations.py @@ -1,7 +1,10 @@ -from refiners.fluxion.layers.module import Module -from torch.nn.functional import silu from torch import Tensor, sigmoid -from torch.nn.functional import gelu # type: ignore +from torch.nn.functional import ( + gelu, # type: ignore + silu, +) + +from refiners.fluxion.layers.module import Module class Activation(Module): diff --git a/src/refiners/fluxion/layers/attentions.py b/src/refiners/fluxion/layers/attentions.py index 2618b3e..de9e225 100644 --- a/src/refiners/fluxion/layers/attentions.py +++ b/src/refiners/fluxion/layers/attentions.py @@ -2,14 +2,14 @@ import math import torch from jaxtyping import Float -from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore from torch import Tensor, device as Device, dtype as DType +from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore +from refiners.fluxion.context import Contexts +from refiners.fluxion.layers.basics import Identity +from refiners.fluxion.layers.chain import Chain, Distribute, Lambda, Parallel from refiners.fluxion.layers.linear import Linear from refiners.fluxion.layers.module import Module -from refiners.fluxion.layers.chain import Chain, Distribute, Parallel, Lambda -from refiners.fluxion.layers.basics import Identity -from refiners.fluxion.context import Contexts def scaled_dot_product_attention( diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py index cad3355..5f15050 100644 --- a/src/refiners/fluxion/layers/basics.py +++ b/src/refiners/fluxion/layers/basics.py @@ -1,8 +1,9 @@ -from refiners.fluxion.layers.module import Module, WeightedModule import torch -from torch import randn, Tensor, Size, device as Device, dtype as DType +from torch import Size, Tensor, device as Device, dtype as DType, randn from torch.nn import Parameter as TorchParameter +from refiners.fluxion.layers.module import Module, WeightedModule + class Identity(Module): def __init__(self) -> None: diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index d67cd6e..497cac9 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -1,15 +1,16 @@ -from collections import defaultdict import inspect import re import sys import traceback +from collections import defaultdict from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload + import torch from torch import Tensor, cat, device as Device, dtype as DType -from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule -from refiners.fluxion.context import Contexts, ContextProvider -from refiners.fluxion.utils import summarize_tensor +from refiners.fluxion.context import ContextProvider, Contexts +from refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule +from refiners.fluxion.utils import summarize_tensor T = TypeVar("T", bound=Module) TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10 diff --git a/src/refiners/fluxion/layers/conv.py b/src/refiners/fluxion/layers/conv.py index baab860..3309122 100644 --- a/src/refiners/fluxion/layers/conv.py +++ b/src/refiners/fluxion/layers/conv.py @@ -1,4 +1,5 @@ -from torch import nn, device as Device, dtype as DType +from torch import device as Device, dtype as DType, nn + from refiners.fluxion.layers.module import WeightedModule diff --git a/src/refiners/fluxion/layers/converter.py b/src/refiners/fluxion/layers/converter.py index 268bf5d..8826a68 100644 --- a/src/refiners/fluxion/layers/converter.py +++ b/src/refiners/fluxion/layers/converter.py @@ -1,6 +1,7 @@ -from refiners.fluxion.layers.module import ContextModule from torch import Tensor +from refiners.fluxion.layers.module import ContextModule + class Converter(ContextModule): """ diff --git a/src/refiners/fluxion/layers/embedding.py b/src/refiners/fluxion/layers/embedding.py index 7dccc3a..81eb6b1 100644 --- a/src/refiners/fluxion/layers/embedding.py +++ b/src/refiners/fluxion/layers/embedding.py @@ -1,8 +1,8 @@ -from refiners.fluxion.layers.module import WeightedModule -from torch.nn import Embedding as _Embedding -from torch import Tensor, device as Device, dtype as DType - from jaxtyping import Float, Int +from torch import Tensor, device as Device, dtype as DType +from torch.nn import Embedding as _Embedding + +from refiners.fluxion.layers.module import WeightedModule class Embedding(_Embedding, WeightedModule): # type: ignore diff --git a/src/refiners/fluxion/layers/linear.py b/src/refiners/fluxion/layers/linear.py index 1135572..5b65677 100644 --- a/src/refiners/fluxion/layers/linear.py +++ b/src/refiners/fluxion/layers/linear.py @@ -1,11 +1,10 @@ -from torch import device as Device, dtype as DType +from jaxtyping import Float +from torch import Tensor, device as Device, dtype as DType from torch.nn import Linear as _Linear -from torch import Tensor -from refiners.fluxion.layers.module import Module, WeightedModule + from refiners.fluxion.layers.activations import ReLU from refiners.fluxion.layers.chain import Chain - -from jaxtyping import Float +from refiners.fluxion.layers.module import Module, WeightedModule class Linear(_Linear, WeightedModule): diff --git a/src/refiners/fluxion/layers/maxpool.py b/src/refiners/fluxion/layers/maxpool.py index 60ffd0a..060f210 100644 --- a/src/refiners/fluxion/layers/maxpool.py +++ b/src/refiners/fluxion/layers/maxpool.py @@ -1,4 +1,5 @@ from torch import nn + from refiners.fluxion.layers.module import Module diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 4232476..1fc663e 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -1,17 +1,15 @@ -from collections import defaultdict -from inspect import signature, Parameter import sys +from collections import defaultdict +from inspect import Parameter, signature from pathlib import Path from types import ModuleType -from typing import Any, DefaultDict, Generator, TypeVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast from torch import device as Device, dtype as DType from torch.nn.modules.module import Module as TorchModule -from refiners.fluxion.utils import load_from_safetensors from refiners.fluxion.context import Context, ContextProvider - -from typing import TYPE_CHECKING, Sequence +from refiners.fluxion.utils import load_from_safetensors if TYPE_CHECKING: from refiners.fluxion.layers.chain import Chain diff --git a/src/refiners/fluxion/layers/norm.py b/src/refiners/fluxion/layers/norm.py index 97016cb..5ac6784 100644 --- a/src/refiners/fluxion/layers/norm.py +++ b/src/refiners/fluxion/layers/norm.py @@ -1,5 +1,6 @@ -from torch import nn, ones, zeros, Tensor, sqrt, device as Device, dtype as DType from jaxtyping import Float +from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros + from refiners.fluxion.layers.module import Module, WeightedModule diff --git a/src/refiners/fluxion/layers/padding.py b/src/refiners/fluxion/layers/padding.py index d8d7377..c65f472 100644 --- a/src/refiners/fluxion/layers/padding.py +++ b/src/refiners/fluxion/layers/padding.py @@ -1,4 +1,5 @@ from torch import nn + from refiners.fluxion.layers.module import Module diff --git a/src/refiners/fluxion/layers/pixelshuffle.py b/src/refiners/fluxion/layers/pixelshuffle.py index 003dafc..adcfc42 100644 --- a/src/refiners/fluxion/layers/pixelshuffle.py +++ b/src/refiners/fluxion/layers/pixelshuffle.py @@ -1,6 +1,7 @@ -from refiners.fluxion.layers.module import Module from torch.nn import PixelUnshuffle as _PixelUnshuffle +from refiners.fluxion.layers.module import Module + class PixelUnshuffle(_PixelUnshuffle, Module): def __init__(self, downscale_factor: int): diff --git a/src/refiners/fluxion/layers/sampling.py b/src/refiners/fluxion/layers/sampling.py index e5e75ff..d6368e3 100644 --- a/src/refiners/fluxion/layers/sampling.py +++ b/src/refiners/fluxion/layers/sampling.py @@ -1,11 +1,11 @@ -from refiners.fluxion.layers.chain import Chain, UseContext, SetContext -from refiners.fluxion.layers.conv import Conv2d +from torch import Size, Tensor, device as Device, dtype as DType +from torch.nn.functional import pad + from refiners.fluxion.layers.basics import Identity -from refiners.fluxion.layers.chain import Parallel, Lambda +from refiners.fluxion.layers.chain import Chain, Lambda, Parallel, SetContext, UseContext +from refiners.fluxion.layers.conv import Conv2d from refiners.fluxion.layers.module import Module from refiners.fluxion.utils import interpolate -from torch.nn.functional import pad -from torch import Tensor, Size, device as Device, dtype as DType class Downsample(Chain): diff --git a/src/refiners/fluxion/model_converter.py b/src/refiners/fluxion/model_converter.py index 4c827f4..ef14d02 100644 --- a/src/refiners/fluxion/model_converter.py +++ b/src/refiners/fluxion/model_converter.py @@ -1,10 +1,11 @@ from collections import defaultdict from enum import Enum, auto from pathlib import Path +from typing import Any, DefaultDict, TypedDict + +import torch from torch import Tensor, nn from torch.utils.hooks import RemovableHandle -import torch -from typing import Any, DefaultDict, TypedDict from refiners.fluxion.utils import norm, save_to_safetensors diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 3f5e710..7c4f5e0 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -1,15 +1,14 @@ -from typing import Iterable, Literal, TypeVar -from PIL import Image -from numpy import array, float32 from pathlib import Path +from typing import Iterable, Literal, TypeVar + +import torch +from jaxtyping import Float +from numpy import array, float32 +from PIL import Image from safetensors import safe_open as _safe_open # type: ignore from safetensors.torch import save_file as _save_file # type: ignore -from torch import norm as _norm, manual_seed as _manual_seed # type: ignore -import torch -from torch.nn.functional import pad as _pad, interpolate as _interpolate, conv2d # type: ignore -from torch import Tensor, device as Device, dtype as DType -from jaxtyping import Float - +from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore +from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore T = TypeVar("T") E = TypeVar("E") diff --git a/src/refiners/foundationals/clip/common.py b/src/refiners/foundationals/clip/common.py index 77a9a94..7a10949 100644 --- a/src/refiners/foundationals/clip/common.py +++ b/src/refiners/foundationals/clip/common.py @@ -1,4 +1,5 @@ from torch import Tensor, arange, device as Device, dtype as DType + import refiners.fluxion.layers as fl diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py index 3380e5a..403b27f 100644 --- a/src/refiners/foundationals/clip/concepts.py +++ b/src/refiners/foundationals/clip/concepts.py @@ -1,13 +1,14 @@ +import re +from typing import cast + +import torch.nn.functional as F +from torch import Tensor, cat, zeros +from torch.nn import Parameter + +import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder from refiners.foundationals.clip.tokenizer import CLIPTokenizer -import refiners.fluxion.layers as fl -from typing import cast - -from torch import Tensor, cat, zeros -import torch.nn.functional as F -from torch.nn import Parameter -import re class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]): diff --git a/src/refiners/foundationals/clip/image_encoder.py b/src/refiners/foundationals/clip/image_encoder.py index b969ed1..ed6db3d 100644 --- a/src/refiners/foundationals/clip/image_encoder.py +++ b/src/refiners/foundationals/clip/image_encoder.py @@ -1,6 +1,7 @@ from torch import device as Device, dtype as DType + import refiners.fluxion.layers as fl -from refiners.foundationals.clip.common import PositionalEncoder, FeedForward +from refiners.foundationals.clip.common import FeedForward, PositionalEncoder class ClassToken(fl.Chain): diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py index 5a532c6..54d5790 100644 --- a/src/refiners/foundationals/clip/text_encoder.py +++ b/src/refiners/foundationals/clip/text_encoder.py @@ -1,6 +1,7 @@ from torch import device as Device, dtype as DType + import refiners.fluxion.layers as fl -from refiners.foundationals.clip.common import PositionalEncoder, FeedForward +from refiners.foundationals.clip.common import FeedForward, PositionalEncoder from refiners.foundationals.clip.tokenizer import CLIPTokenizer diff --git a/src/refiners/foundationals/clip/tokenizer.py b/src/refiners/foundationals/clip/tokenizer.py index 6f06531..585474d 100644 --- a/src/refiners/foundationals/clip/tokenizer.py +++ b/src/refiners/foundationals/clip/tokenizer.py @@ -1,11 +1,13 @@ import gzip -from pathlib import Path +import re from functools import lru_cache from itertools import islice -import re +from pathlib import Path + from torch import Tensor, tensor -from refiners.fluxion import pad + import refiners.fluxion.layers as fl +from refiners.fluxion import pad class CLIPTokenizer(fl.Module): diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py index 42bdaf8..68b6517 100644 --- a/src/refiners/foundationals/latent_diffusion/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/__init__.py @@ -1,27 +1,26 @@ -from refiners.foundationals.latent_diffusion.auto_encoder import ( - LatentDiffusionAutoencoder, -) from refiners.foundationals.clip.text_encoder import ( CLIPTextEncoderL, ) +from refiners.foundationals.latent_diffusion.auto_encoder import ( + LatentDiffusionAutoencoder, +) from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter -from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver +from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler from refiners.foundationals.latent_diffusion.stable_diffusion_1 import ( - StableDiffusion_1, - StableDiffusion_1_Inpainting, - SD1UNet, SD1ControlnetAdapter, SD1IPAdapter, SD1T2IAdapter, + SD1UNet, + StableDiffusion_1, + StableDiffusion_1_Inpainting, ) from refiners.foundationals.latent_diffusion.stable_diffusion_xl import ( - SDXLUNet, DoubleTextEncoder, SDXLIPAdapter, SDXLT2IAdapter, + SDXLUNet, ) - __all__ = [ "StableDiffusion_1", "StableDiffusion_1_Inpainting", diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py index 3d0bdd3..2dc3bd5 100644 --- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py @@ -1,20 +1,21 @@ +from PIL import Image +from torch import Tensor, device as Device, dtype as DType + from refiners.fluxion.context import Contexts from refiners.fluxion.layers import ( Chain, Conv2d, + Downsample, GroupNorm, Identity, - SiLU, - Downsample, - Upsample, - Sum, - SelfAttention2d, - Slicing, Residual, + SelfAttention2d, + SiLU, + Slicing, + Sum, + Upsample, ) from refiners.fluxion.utils import image_to_tensor, tensor_to_image -from torch import Tensor, device as Device, dtype as DType -from PIL import Image class Resnet(Sum): diff --git a/src/refiners/foundationals/latent_diffusion/cross_attention.py b/src/refiners/foundationals/latent_diffusion/cross_attention.py index 18fd7a6..7b9497e 100644 --- a/src/refiners/foundationals/latent_diffusion/cross_attention.py +++ b/src/refiners/foundationals/latent_diffusion/cross_attention.py @@ -1,24 +1,24 @@ -from torch import Tensor, Size, device as Device, dtype as DType +from torch import Size, Tensor, device as Device, dtype as DType from refiners.fluxion.context import Contexts from refiners.fluxion.layers import ( - Identity, - Flatten, - Unflatten, - Transpose, - Chain, - Parallel, - LayerNorm, - Attention, - UseContext, - Linear, GLU, + Attention, + Chain, + Conv2d, + Flatten, GeLU, GroupNorm, - Conv2d, + Identity, + LayerNorm, + Linear, + Parallel, + Residual, SelfAttention, SetContext, - Residual, + Transpose, + Unflatten, + UseContext, ) diff --git a/src/refiners/foundationals/latent_diffusion/freeu.py b/src/refiners/foundationals/latent_diffusion/freeu.py index d598d01..61726bd 100644 --- a/src/refiners/foundationals/latent_diffusion/freeu.py +++ b/src/refiners/foundationals/latent_diffusion/freeu.py @@ -1,13 +1,14 @@ import math from typing import Any, Generic, TypeVar -import refiners.fluxion.layers as fl import torch +from torch import Tensor +from torch.fft import fftn, fftshift, ifftn, ifftshift # type: ignore + +import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualConcatenator, SD1UNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet -from torch import Tensor -from torch.fft import fftn, fftshift, ifftn, ifftshift # type: ignore T = TypeVar("T", bound="SD1UNet | SDXLUNet") TSDFreeUAdapter = TypeVar("TSDFreeUAdapter", bound="SDFreeUAdapter[Any]") # Self (see PEP 673) diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 72e1247..c25b82b 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -1,19 +1,19 @@ +import math from enum import IntEnum from functools import partial -from typing import Generic, TypeVar, Any, Callable, TYPE_CHECKING -import math +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar from jaxtyping import Float -from torch import Tensor, cat, softmax, zeros_like, device as Device, dtype as DType from PIL import Image +from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like +import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.lora import Lora -from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH from refiners.fluxion.context import Contexts from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.utils import image_to_tensor, normalize -import refiners.fluxion.layers as fl +from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH if TYPE_CHECKING: from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index d0d52c3..bc041c8 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -1,23 +1,21 @@ from enum import Enum from pathlib import Path -from typing import Iterator, Callable +from typing import Callable, Iterator from torch import Tensor import refiners.fluxion.layers as fl -from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors - from refiners.fluxion.adapters.adapter import Adapter -from refiners.fluxion.adapters.lora import LoraAdapter, Lora - +from refiners.fluxion.adapters.lora import Lora, LoraAdapter +from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer -from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.foundationals.latent_diffusion import ( - StableDiffusion_1, - SD1UNet, CLIPTextEncoderL, LatentDiffusionAutoencoder, + SD1UNet, + StableDiffusion_1, ) +from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet MODELS = ["unet", "text_encoder", "lda"] diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 7f5f5da..a283a0b 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod from typing import TypeVar -from torch import Tensor, device as Device, dtype as DType -from PIL import Image + import torch +from PIL import Image +from torch import Tensor, device as Device, dtype as DType + import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler diff --git a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py index 19a0a08..31deac0 100644 --- a/src/refiners/foundationals/latent_diffusion/multi_diffusion.py +++ b/src/refiners/foundationals/latent_diffusion/multi_diffusion.py @@ -8,7 +8,6 @@ from torch import Tensor, device as Device, dtype as DType from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel - MAX_STEPS = 1000 diff --git a/src/refiners/foundationals/latent_diffusion/preprocessors/informative_drawings.py b/src/refiners/foundationals/latent_diffusion/preprocessors/informative_drawings.py index 738a880..ab84fb9 100644 --- a/src/refiners/foundationals/latent_diffusion/preprocessors/informative_drawings.py +++ b/src/refiners/foundationals/latent_diffusion/preprocessors/informative_drawings.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/carolineec/informative-drawings, MIT License from torch import device as Device, dtype as DType + import refiners.fluxion.layers as fl diff --git a/src/refiners/foundationals/latent_diffusion/range_adapter.py b/src/refiners/foundationals/latent_diffusion/range_adapter.py index 5b5fe97..317c13b 100644 --- a/src/refiners/foundationals/latent_diffusion/range_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/range_adapter.py @@ -1,9 +1,10 @@ import math -from torch import Tensor, arange, float32, exp, sin, cat, cos, device as Device, dtype as DType -from jaxtyping import Float, Int -from refiners.fluxion.adapters.adapter import Adapter +from jaxtyping import Float, Int +from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin + import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.adapter import Adapter def compute_sinusoidal_embedding( diff --git a/src/refiners/foundationals/latent_diffusion/reference_only_control.py b/src/refiners/foundationals/latent_diffusion/reference_only_control.py index 9e45f5c..5ef99de 100644 --- a/src/refiners/foundationals/latent_diffusion/reference_only_control.py +++ b/src/refiners/foundationals/latent_diffusion/reference_only_control.py @@ -1,18 +1,19 @@ +from torch import Tensor + +from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.layers import ( - Passthrough, - Lambda, Chain, Concatenate, - UseContext, + Identity, + Lambda, + Parallel, + Passthrough, SelfAttention, SetContext, - Identity, - Parallel, + UseContext, ) -from refiners.fluxion.adapters.adapter import Adapter from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock -from torch import Tensor class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]): diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py index 90ccc1a..5a9be28 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/__init__.py @@ -1,7 +1,7 @@ -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler -from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver -from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM +from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM +from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver +from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler __all__ = [ "Scheduler", diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py index de2e866..200640d 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py @@ -1,4 +1,5 @@ -from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor +from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor + from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index fafdd58..01417e7 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,4 +1,5 @@ -from torch import Tensor, device as Device, randn, arange, Generator, tensor +from torch import Generator, Tensor, arange, device as Device, randn, tensor + from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index fe4d7e3..e1e11dc 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -1,8 +1,10 @@ -from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler -import numpy as np -from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype from collections import deque +import numpy as np +from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor + +from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler + class DPMSolver(Scheduler): """Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index e0d1fba..abf106c 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from enum import Enum -from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log from typing import TypeVar +from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt + T = TypeVar("T", bound="Scheduler") diff --git a/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py b/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py index 711e730..1916a58 100644 --- a/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py +++ b/src/refiners/foundationals/latent_diffusion/self_attention_guidance.py @@ -1,15 +1,15 @@ -from typing import Any, Generic, TypeVar, TYPE_CHECKING import math +from typing import TYPE_CHECKING, Any, Generic, TypeVar -from torch import Tensor, Size -from jaxtyping import Float import torch +from jaxtyping import Float +from torch import Size, Tensor -from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler +import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.context import Contexts -from refiners.fluxion.utils import interpolate, gaussian_blur -import refiners.fluxion.layers as fl +from refiners.fluxion.utils import gaussian_blur, interpolate +from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler if TYPE_CHECKING: from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py index 8b94b69..4064f4b 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py @@ -1,11 +1,11 @@ -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet +from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import ( StableDiffusion_1, StableDiffusion_1_Inpainting, ) -from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter -from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet __all__ = [ "StableDiffusion_1", diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 99678dd..21e2f50 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -1,16 +1,18 @@ +from typing import Iterable, cast + +from torch import Tensor, device as Device, dtype as DType + +from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.context import Contexts -from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Slicing, Residual +from refiners.fluxion.layers import Chain, Conv2d, Lambda, Passthrough, Residual, SiLU, Slicing, UseContext +from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( - SD1UNet, DownBlocks, MiddleBlock, ResidualBlock, + SD1UNet, TimestepEncoder, ) -from refiners.fluxion.adapters.adapter import Adapter -from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d -from typing import cast, Iterable -from torch import Tensor, device as Device, dtype as DType class ConditionEncoder(Chain): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py index d1e3379..aa78cce 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/image_prompt.py @@ -2,7 +2,7 @@ from torch import Tensor from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler +from refiners.foundationals.latent_diffusion.image_prompt import ImageProjection, IPAdapter, PerceiverResampler from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 9f82bd3..532c68f 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -1,15 +1,16 @@ +import numpy as np import torch +from PIL import Image +from torch import Tensor, device as Device, dtype as DType + from refiners.fluxion.utils import image_to_tensor, interpolate from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet from refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter -from PIL import Image -import numpy as np -from torch import device as Device, dtype as DType, Tensor +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet class SD1Autoencoder(LatentDiffusionAutoencoder): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py index 7e37785..6b18302 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py @@ -1,6 +1,7 @@ -from dataclasses import field, dataclass -from torch import Tensor +from dataclasses import dataclass, field + from PIL import Image +from torch import Tensor from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import ( diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/self_attention_guidance.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/self_attention_guidance.py index e92471e..674f396 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/self_attention_guidance.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/self_attention_guidance.py @@ -1,11 +1,11 @@ +import refiners.fluxion.layers as fl +from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.foundationals.latent_diffusion.self_attention_guidance import ( SAGAdapter, - SelfAttentionShape, SelfAttentionMap, + SelfAttentionShape, ) -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, MiddleBlock, ResidualBlock -from refiners.fluxion.layers.attentions import ScaledDotProductAttention -import refiners.fluxion.layers as fl +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import MiddleBlock, ResidualBlock, SD1UNet class SD1SAGAdapter(SAGAdapter[SD1UNet]): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py index 7cbb49c..c70b7ce 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py @@ -1,8 +1,8 @@ from torch import Tensor -from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, ResidualAccumulator import refiners.fluxion.layers as fl +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator, SD1UNet +from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, T2IAdapter, T2IFeatures class SD1T2IAdapter(T2IAdapter[SD1UNet]): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index e2a8f89..3d8c967 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -1,12 +1,11 @@ -from typing import cast, Iterable +from typing import Iterable, cast from torch import Tensor, device as Device, dtype as DType -from refiners.fluxion.context import Contexts import refiners.fluxion.layers as fl - +from refiners.fluxion.context import Contexts from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder, RangeAdapter2d +from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder class TimestepEncoder(fl.Passthrough): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py index 775a0e1..43f2c36 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py @@ -1,9 +1,8 @@ -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter - +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet __all__ = [ "SDXLUNet", diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py index 93d9e8b..74d1372 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py @@ -2,7 +2,7 @@ from torch import Tensor from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from refiners.foundationals.latent_diffusion.image_prompt import IPAdapter, ImageProjection, PerceiverResampler +from refiners.foundationals.latent_diffusion.image_prompt import ImageProjection, IPAdapter, PerceiverResampler from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py index b8b52e6..0cb979b 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py @@ -1,12 +1,13 @@ import torch +from torch import Tensor, device as Device, dtype as DType + from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder -from torch import device as Device, dtype as DType, Tensor +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet class SDXLAutoencoder(LatentDiffusionAutoencoder): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/self_attention_guidance.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/self_attention_guidance.py index d125156..985687f 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/self_attention_guidance.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/self_attention_guidance.py @@ -1,11 +1,11 @@ +import refiners.fluxion.layers as fl +from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.foundationals.latent_diffusion.self_attention_guidance import ( SAGAdapter, - SelfAttentionShape, SelfAttentionMap, + SelfAttentionShape, ) -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet, MiddleBlock, ResidualBlock -from refiners.fluxion.layers.attentions import ScaledDotProductAttention -import refiners.fluxion.layers as fl +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, ResidualBlock, SDXLUNet class SDXLSAGAdapter(SAGAdapter[SDXLUNet]): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py index cab08d6..e955422 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py @@ -1,9 +1,9 @@ from torch import Tensor -from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL -from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator import refiners.fluxion.layers as fl +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator +from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet +from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures class SDXLT2IAdapter(T2IAdapter[SDXLUNet]): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py index 12b32db..72a6245 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py @@ -1,11 +1,12 @@ from typing import cast -from torch import device as Device, dtype as DType, Tensor, cat + +from jaxtyping import Float +from torch import Tensor, cat, device as Device, dtype as DType + +import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.context import Contexts -import refiners.fluxion.layers as fl from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL -from jaxtyping import Float - from refiners.foundationals.clip.tokenizer import CLIPTokenizer diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 4edf8c1..c2b49c8 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -1,18 +1,20 @@ from typing import cast + from torch import Tensor, device as Device, dtype as DType -from refiners.fluxion.context import Contexts + import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( - ResidualAccumulator, - ResidualBlock, - ResidualConcatenator, -) from refiners.foundationals.latent_diffusion.range_adapter import ( RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding, ) +from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( + ResidualAccumulator, + ResidualBlock, + ResidualConcatenator, +) class TextTimeEmbedding(fl.Chain): diff --git a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py index 1efdb52..ba49f5e 100644 --- a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py @@ -1,12 +1,12 @@ -from typing import Generic, TypeVar, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Generic, TypeVar from torch import Tensor, device as Device, dtype as DType from torch.nn import AvgPool2d as _AvgPool2d +import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.context import Contexts from refiners.fluxion.layers.module import Module -import refiners.fluxion.layers as fl if TYPE_CHECKING: from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet diff --git a/src/refiners/foundationals/segment_anything/image_encoder.py b/src/refiners/foundationals/segment_anything/image_encoder.py index 1e02e33..118731e 100644 --- a/src/refiners/foundationals/segment_anything/image_encoder.py +++ b/src/refiners/foundationals/segment_anything/image_encoder.py @@ -1,9 +1,9 @@ -from torch import device as Device, dtype as DType, Tensor -from refiners.fluxion.context import Contexts -import refiners.fluxion.layers as fl -from refiners.fluxion.utils import pad -from torch import nn import torch +from torch import Tensor, device as Device, dtype as DType, nn + +import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts +from refiners.fluxion.utils import pad class PatchEncoder(fl.Chain): diff --git a/src/refiners/foundationals/segment_anything/mask_decoder.py b/src/refiners/foundationals/segment_anything/mask_decoder.py index 5502bca..b0ee47d 100644 --- a/src/refiners/foundationals/segment_anything/mask_decoder.py +++ b/src/refiners/foundationals/segment_anything/mask_decoder.py @@ -1,12 +1,12 @@ -import refiners.fluxion.layers as fl -from torch import device as Device, dtype as DType, Tensor, nn import torch +from torch import Tensor, device as Device, dtype as DType, nn +import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts from refiners.foundationals.segment_anything.transformer import ( SparseCrossDenseAttention, TwoWayTranformerLayer, ) -from refiners.fluxion.context import Contexts class EmbeddingsAggregator(fl.ContextModule): diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 1c841c2..6f83f72 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import Sequence -from PIL import Image -from torch import device as Device, dtype as DType, Tensor + import numpy as np import torch +from PIL import Image +from torch import Tensor, device as Device, dtype as DType + import refiners.fluxion.layers as fl -from refiners.fluxion.utils import image_to_tensor, normalize, pad, interpolate +from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py index c803d46..222ae4e 100644 --- a/src/refiners/foundationals/segment_anything/prompt_encoder.py +++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py @@ -1,8 +1,10 @@ -from enum import Enum, auto from collections.abc import Sequence -from torch import device as Device, dtype as DType, Tensor, nn +from enum import Enum, auto + import torch from jaxtyping import Float, Int +from torch import Tensor, device as Device, dtype as DType, nn + import refiners.fluxion.layers as fl from refiners.fluxion.context import Contexts diff --git a/src/refiners/foundationals/segment_anything/transformer.py b/src/refiners/foundationals/segment_anything/transformer.py index 4c72bab..5fb13d7 100644 --- a/src/refiners/foundationals/segment_anything/transformer.py +++ b/src/refiners/foundationals/segment_anything/transformer.py @@ -1,4 +1,5 @@ -from torch import dtype as DType, device as Device +from torch import device as Device, dtype as DType + import refiners.fluxion.layers as fl diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py index 66971aa..f4c9950 100644 --- a/src/refiners/training_utils/__init__.py +++ b/src/refiners/training_utils/__init__.py @@ -1,7 +1,8 @@ +import sys from importlib import import_module from importlib.metadata import requires + from packaging.requirements import Requirement -import sys refiners_requires = requires("refiners") assert refiners_requires is not None diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index b2ddd4d..7f9bdef 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -1,7 +1,8 @@ -from typing import TYPE_CHECKING, Generic, Iterable, Any, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Iterable, TypeVar + +from loguru import logger from torch import tensor from torch.nn import Parameter -from loguru import logger if TYPE_CHECKING: from refiners.training_utils.config import BaseConfig diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 6c21461..fdf979a 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -1,17 +1,18 @@ +from enum import Enum from logging import warn from pathlib import Path from typing import Any, Callable, Iterable, Literal, Type, TypeVar -from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version -from torch.optim import AdamW, SGD, Optimizer, Adam -from torch.nn import Parameter -from enum import Enum -from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore -from pydantic import BaseModel, validator -import tomli -import refiners.fluxion.layers as fl -from prodigyopt import Prodigy # type: ignore -from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout +import tomli +from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore +from prodigyopt import Prodigy # type: ignore +from pydantic import BaseModel, validator +from torch.nn import Parameter +from torch.optim import SGD, Adam, AdamW, Optimizer +from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version + +import refiners.fluxion.layers as fl +from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout __all__ = [ "parse_number_unit_field", diff --git a/src/refiners/training_utils/dropout.py b/src/refiners/training_utils/dropout.py index 90999ac..37c188e 100644 --- a/src/refiners/training_utils/dropout.py +++ b/src/refiners/training_utils/dropout.py @@ -1,11 +1,11 @@ from typing import TYPE_CHECKING, Any, TypeVar -from torch import Tensor, randint, cat, rand +from torch import Tensor, cat, rand, randint from torch.nn import Dropout as TorchDropout import refiners.fluxion.layers as fl -from refiners.training_utils.callback import Callback from refiners.fluxion.adapters.adapter import Adapter +from refiners.training_utils.callback import Callback if TYPE_CHECKING: from refiners.training_utils.config import BaseConfig diff --git a/src/refiners/training_utils/huggingface_datasets.py b/src/refiners/training_utils/huggingface_datasets.py index 956ad0f..5b41a33 100644 --- a/src/refiners/training_utils/huggingface_datasets.py +++ b/src/refiners/training_utils/huggingface_datasets.py @@ -1,6 +1,7 @@ -from datasets import load_dataset as _load_dataset, VerificationMode # type: ignore from typing import Any, Generic, Protocol, TypeVar, cast +from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore + __all__ = ["load_hf_dataset", "HuggingfaceDataset"] diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 187a1fb..200daaf 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -1,29 +1,31 @@ +import random from dataclasses import dataclass -from typing import Any, TypeVar, TypedDict, Callable -from pydantic import BaseModel -from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat -from loguru import logger -from torch.utils.data import Dataset -from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL -from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # type: ignore -import refiners.fluxion.layers as fl -from PIL import Image from functools import cached_property -from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder -from refiners.training_utils.config import BaseConfig +from typing import Any, Callable, TypedDict, TypeVar + +from loguru import logger +from PIL import Image +from pydantic import BaseModel +from torch import Generator, Tensor, cat, device as Device, dtype as DType, randn +from torch.nn import Module +from torch.nn.functional import mse_loss +from torch.utils.data import Dataset +from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # type: ignore + +import refiners.fluxion.layers as fl +from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.latent_diffusion import ( - StableDiffusion_1, DPMSolver, SD1UNet, + StableDiffusion_1, ) from refiners.foundationals.latent_diffusion.schedulers import DDPM -from torch.nn.functional import mse_loss -import random -from refiners.training_utils.wandb import WandbLoggable -from refiners.training_utils.trainer import Trainer +from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.training_utils.callback import Callback -from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset -from torch.nn import Module +from refiners.training_utils.config import BaseConfig +from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, load_hf_dataset +from refiners.training_utils.trainer import Trainer +from refiners.training_utils.wandb import WandbLoggable class LatentDiffusionConfig(BaseModel): diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index ea079d0..4ccb014 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -1,41 +1,43 @@ -from functools import cached_property, wraps -from pathlib import Path import random import time +from functools import cached_property, wraps +from pathlib import Path +from typing import Any, Callable, Generic, Iterable, TypeVar, cast + import numpy as np -from torch import device as Device, Tensor, get_rng_state, no_grad, set_rng_state, cuda, stack +from loguru import logger +from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack +from torch.autograd import backward from torch.nn import Parameter from torch.optim import Optimizer +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + CosineAnnealingWarmRestarts, + CyclicLR, + ExponentialLR, + LambdaLR, + LRScheduler, + MultiplicativeLR, + MultiStepLR, + OneCycleLR, + ReduceLROnPlateau, + StepLR, +) from torch.utils.data import DataLoader, Dataset -from torch.autograd import backward -from typing import Any, Callable, Generic, Iterable, TypeVar, cast -from loguru import logger + from refiners.fluxion import layers as fl from refiners.fluxion.utils import manual_seed -from refiners.training_utils.wandb import WandbLogger, WandbLoggable -from refiners.training_utils.config import BaseConfig, TimeUnit, TimeValue, SchedulerType -from refiners.training_utils.dropout import DropoutCallback from refiners.training_utils.callback import ( Callback, ClockCallback, GradientNormClipping, - GradientValueClipping, GradientNormLogging, + GradientValueClipping, MonitorLoss, ) -from torch.optim.lr_scheduler import ( - StepLR, - ExponentialLR, - ReduceLROnPlateau, - CosineAnnealingLR, - LambdaLR, - OneCycleLR, - LRScheduler, - MultiplicativeLR, - CosineAnnealingWarmRestarts, - CyclicLR, - MultiStepLR, -) +from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue +from refiners.training_utils.dropout import DropoutCallback +from refiners.training_utils.wandb import WandbLoggable, WandbLogger __all__ = ["seed_everything", "scoped_seed", "Trainer"] diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index c093aba..2683f89 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -1,4 +1,5 @@ from typing import Any + import wandb from PIL import Image diff --git a/tests/adapters/test_adapter.py b/tests/adapters/test_adapter.py index daa4ef1..8049f97 100644 --- a/tests/adapters/test_adapter.py +++ b/tests/adapters/test_adapter.py @@ -1,4 +1,5 @@ import pytest + from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.layers import Chain, Linear diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py index 8d25736..bdedfeb 100644 --- a/tests/adapters/test_lora.py +++ b/tests/adapters/test_lora.py @@ -1,6 +1,7 @@ -from refiners.fluxion.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter -from torch import randn, allclose +from torch import allclose, randn + import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.lora import Lora, LoraAdapter, SingleLoraAdapter def test_single_lora_adapter() -> None: diff --git a/tests/adapters/test_range_adapter.py b/tests/adapters/test_range_adapter.py index ede27d4..90e3bdc 100644 --- a/tests/adapters/test_range_adapter.py +++ b/tests/adapters/test_range_adapter.py @@ -1,7 +1,8 @@ import torch + from refiners.fluxion.adapters.adapter import Adapter -from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder from refiners.fluxion.layers import Chain, Linear +from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder class DummyLinearAdapter(Chain, Adapter[Linear]): diff --git a/tests/conftest.py b/tests/conftest.py index b57459d..d1403ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import os -import torch from pathlib import Path + +import torch from pytest import fixture PARENT_PATH = Path(__file__).parent diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index dcbf30e..201c1b0 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -1,34 +1,32 @@ -import torch -import pytest - -from typing import Iterator - -from warnings import warn -from PIL import Image from pathlib import Path +from typing import Iterator +from warnings import warn -from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed +import pytest +import torch +from PIL import Image + +from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed +from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.latent_diffusion import ( - StableDiffusion_1, - StableDiffusion_1_Inpainting, - SD1UNet, SD1ControlnetAdapter, SD1IPAdapter, SD1T2IAdapter, + SD1UNet, + SDFreeUAdapter, SDXLIPAdapter, SDXLT2IAdapter, - SDFreeUAdapter, + StableDiffusion_1, + StableDiffusion_1_Inpainting, ) from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget +from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.restart import Restart from refiners.foundationals.latent_diffusion.schedulers import DDIM -from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter -from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL - from tests.utils import ensure_similar_images diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py index 7dd9ddc..69131b2 100644 --- a/tests/e2e/test_preprocessors.py +++ b/tests/e2e/test_preprocessors.py @@ -1,13 +1,12 @@ -import torch -import pytest - -from warnings import warn -from PIL import Image from pathlib import Path +from warnings import warn + +import pytest +import torch +from PIL import Image from refiners.fluxion.utils import image_to_tensor, tensor_to_image from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings - from tests.utils import ensure_similar_images diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index b678716..5b5cd1f 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -1,5 +1,6 @@ import pytest import torch + import refiners.fluxion.layers as fl from refiners.fluxion.context import Contexts diff --git a/tests/fluxion/layers/test_converter.py b/tests/fluxion/layers/test_converter.py index 611d4fa..8a33188 100644 --- a/tests/fluxion/layers/test_converter.py +++ b/tests/fluxion/layers/test_converter.py @@ -1,7 +1,8 @@ -import torch -import pytest from warnings import warn +import pytest +import torch + import refiners.fluxion.layers as fl from refiners.fluxion.layers.chain import ChainError, Distribute diff --git a/tests/fluxion/test_model_converter.py b/tests/fluxion/test_model_converter.py index 2e5936c..657ce0a 100644 --- a/tests/fluxion/test_model_converter.py +++ b/tests/fluxion/test_model_converter.py @@ -1,10 +1,11 @@ # pyright: reportPrivateUsage=false import pytest import torch -from torch import nn, Tensor -from refiners.fluxion.utils import manual_seed -from refiners.fluxion.model_converter import ModelConverter, ConversionStage +from torch import Tensor, nn + import refiners.fluxion.layers as fl +from refiners.fluxion.model_converter import ConversionStage, ModelConverter +from refiners.fluxion.utils import manual_seed class CustomBasicLayer1(fl.Module): diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index e83b789..8811c47 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -1,11 +1,11 @@ from dataclasses import dataclass from warnings import warn -from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore -from torch import device as Device, dtype as DType -from PIL import Image import pytest import torch +from PIL import Image +from torch import device as Device, dtype as DType +from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index d297f01..ed86561 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -1,18 +1,16 @@ -import torch -import pytest - -from warnings import warn from pathlib import Path +from warnings import warn +import pytest +import torch +import transformers # type: ignore +from diffusers import StableDiffusionPipeline # type: ignore + +import refiners.fluxion.layers as fl +from refiners.fluxion.utils import load_from_safetensors from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.fluxion.utils import load_from_safetensors -import refiners.fluxion.layers as fl - -from diffusers import StableDiffusionPipeline # type: ignore -import transformers # type: ignore - PROMPTS = [ "a cute cat", # a simple prompt diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py index 459fa41..3aac668 100644 --- a/tests/foundationals/clip/test_image_encoder.py +++ b/tests/foundationals/clip/test_image_encoder.py @@ -1,13 +1,12 @@ -import torch -import pytest - -from warnings import warn from pathlib import Path +from warnings import warn +import pytest +import torch from transformers import CLIPVisionModelWithProjection # type: ignore -from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH from refiners.fluxion.utils import load_from_safetensors +from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH @pytest.fixture(scope="module") diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py index eeb501a..0e108b7 100644 --- a/tests/foundationals/clip/test_text_encoder.py +++ b/tests/foundationals/clip/test_text_encoder.py @@ -1,15 +1,13 @@ -import torch -import pytest - -from warnings import warn from pathlib import Path +from warnings import warn -from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL -from refiners.fluxion.utils import load_from_safetensors - +import pytest +import torch import transformers # type: ignore -from refiners.foundationals.clip.tokenizer import CLIPTokenizer +from refiners.fluxion.utils import load_from_safetensors +from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL +from refiners.foundationals.clip.tokenizer import CLIPTokenizer long_prompt = """ Above these apparent hieroglyphics was a figure of evidently pictorial intent, diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py index 287dcc3..2ddca24 100644 --- a/tests/foundationals/latent_diffusion/test_auto_encoder.py +++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py @@ -1,15 +1,14 @@ -import torch -import pytest - -from warnings import warn -from PIL import Image from pathlib import Path +from warnings import warn + +import pytest +import torch +from PIL import Image +from tests.utils import ensure_similar_images from refiners.fluxion.utils import load_from_safetensors from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder -from tests.utils import ensure_similar_images - @pytest.fixture(scope="module") def ref_path() -> Path: diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py index 92f7662..36f3b04 100644 --- a/tests/foundationals/latent_diffusion/test_controlnet.py +++ b/tests/foundationals/latent_diffusion/test_controlnet.py @@ -1,11 +1,11 @@ from typing import Iterator -import torch import pytest +import torch import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import lookup_top_adapter -from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter +from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py index b214084..6b7001b 100644 --- a/tests/foundationals/latent_diffusion/test_freeu.py +++ b/tests/foundationals/latent_diffusion/test_freeu.py @@ -3,9 +3,9 @@ from typing import Iterator import pytest import torch -from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet -from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter, FreeUResidualConcatenator from refiners.fluxion import manual_seed +from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet +from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter @pytest.fixture(scope="module", params=[True, False]) diff --git a/tests/foundationals/latent_diffusion/test_reference_only_control.py b/tests/foundationals/latent_diffusion/test_reference_only_control.py index 95201bd..68833b3 100644 --- a/tests/foundationals/latent_diffusion/test_reference_only_control.py +++ b/tests/foundationals/latent_diffusion/test_reference_only_control.py @@ -1,15 +1,14 @@ -import torch import pytest - +import torch from refiners.foundationals.latent_diffusion import SD1UNet +from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock from refiners.foundationals.latent_diffusion.reference_only_control import ( ReferenceOnlyControlAdapter, SaveLayerNormAdapter, SelfAttentionInjectionAdapter, SelfAttentionInjectionPassthrough, ) -from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock @torch.no_grad() diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index a7f357f..553171a 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -1,9 +1,11 @@ -import pytest from typing import cast from warnings import warn -from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, DDIM + +import pytest +from torch import Tensor, allclose, device as Device, randn + from refiners.fluxion import manual_seed -from torch import randn, Tensor, allclose, device as Device +from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver def test_dpm_solver_diffusers(): diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index 4ff3551..cb51253 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -1,12 +1,13 @@ -from typing import Any, Protocol, cast from pathlib import Path +from typing import Any, Protocol, cast from warnings import warn + import pytest import torch from torch import Tensor -from refiners.fluxion.utils import manual_seed import refiners.fluxion.layers as fl +from refiners.fluxion.utils import manual_seed from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder diff --git a/tests/foundationals/latent_diffusion/test_sdxl_unet.py b/tests/foundationals/latent_diffusion/test_sdxl_unet.py index a79fbba..95b031b 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_unet.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_unet.py @@ -1,12 +1,13 @@ -from typing import Any from pathlib import Path +from typing import Any from warnings import warn + import pytest import torch -from refiners.fluxion.utils import manual_seed -from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet from refiners.fluxion.model_converter import ConversionStage, ModelConverter +from refiners.fluxion.utils import manual_seed +from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet @pytest.fixture(scope="module") diff --git a/tests/foundationals/latent_diffusion/test_unet.py b/tests/foundationals/latent_diffusion/test_unet.py index b56e0e1..4fe09e3 100644 --- a/tests/foundationals/latent_diffusion/test_unet.py +++ b/tests/foundationals/latent_diffusion/test_unet.py @@ -1,7 +1,8 @@ -from refiners.foundationals.latent_diffusion import SD1UNet -from refiners.fluxion import manual_seed import torch +from refiners.fluxion import manual_seed +from refiners.foundationals.latent_diffusion import SD1UNet + def test_unet_context_flush(): manual_seed(0) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index dfc62ac..1e00e64 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -3,26 +3,25 @@ from pathlib import Path from typing import cast from warnings import warn +import numpy as np import pytest import torch import torch.nn as nn -import numpy as np - from PIL import Image -from torch import Tensor -from refiners.fluxion import manual_seed -from refiners.fluxion.model_converter import ModelConverter - -from refiners.fluxion.utils import image_to_tensor -from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention -from refiners.foundationals.segment_anything.model import SegmentAnythingH -from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer from tests.foundationals.segment_anything.utils import ( FacebookSAM, FacebookSAMPredictor, SAMPrompt, intersection_over_union, ) +from torch import Tensor + +from refiners.fluxion import manual_seed +from refiners.fluxion.model_converter import ModelConverter +from refiners.fluxion.utils import image_to_tensor +from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention +from refiners.foundationals.segment_anything.model import SegmentAnythingH +from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer # See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported) PROMPTS: list[SAMPrompt] = [ @@ -235,9 +234,10 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N point_embedding = torch.randn(1, 3, 256, device=facebook_sam_h.device) mask_embedding = torch.randn(1, 256, 64, 64, device=facebook_sam_h.device) - import refiners.fluxion.layers as fl from segment_anything.modeling.common import LayerNorm2d # type: ignore + import refiners.fluxion.layers as fl + assert issubclass(LayerNorm2d, nn.Module) custom_layers = {LayerNorm2d: fl.LayerNorm2d} diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index 274726c..ef73e36 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -2,11 +2,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, TypedDict -from jaxtyping import Bool -from torch import Tensor, nn import numpy as np import numpy.typing as npt import torch +from jaxtyping import Bool +from torch import Tensor, nn NDArrayUInt8 = npt.NDArray[np.uint8] NDArray = npt.NDArray[Any] diff --git a/tests/utils.py b/tests/utils.py index 0cf86ea..a89a86f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,6 @@ -import torch -import piq # type: ignore import numpy as np +import piq # type: ignore +import torch from PIL import Image