diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index 42bda93..3f24d0e 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -6,9 +6,11 @@ from torch import nn from diffusers import ControlNetModel # type: ignore from refiners.fluxion.utils import save_to_safetensors from refiners.fluxion.model_converter import ModelConverter -from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Controlnet -from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver -from refiners.foundationals.latent_diffusion import SD1UNet +from refiners.foundationals.latent_diffusion import ( + SD1UNet, + SD1Controlnet, + DPMSolver, +) class Args(argparse.Namespace): diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py index b7b568f..2942393 100644 --- a/scripts/conversion/convert_diffusers_lora.py +++ b/scripts/conversion/convert_diffusers_lora.py @@ -9,7 +9,7 @@ from diffusers import DiffusionPipeline # type: ignore import refiners.fluxion.layers as fl from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet +from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target from refiners.adapters.lora import Lora diff --git a/scripts/conversion/convert_refiners_lora_to_sdwebui.py b/scripts/conversion/convert_refiners_lora_to_sdwebui.py index f263809..96fb413 100644 --- a/scripts/conversion/convert_refiners_lora_to_sdwebui.py +++ b/scripts/conversion/convert_refiners_lora_to_sdwebui.py @@ -12,8 +12,8 @@ from convert_transformers_clip_text_model import ( Args as TextEncoderConversionArgs, ) from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL +from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget -from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet import refiners.fluxion.layers as fl diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py index 7d3ab1a..cd2d6d1 100644 --- a/src/refiners/foundationals/latent_diffusion/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/__init__.py @@ -5,20 +5,14 @@ from refiners.foundationals.clip.text_encoder import ( CLIPTextEncoderL, ) from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver -from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import ( +from refiners.foundationals.latent_diffusion.stable_diffusion_1 import ( StableDiffusion_1, StableDiffusion_1_Inpainting, -) -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( SD1UNet, -) -from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import ( SD1Controlnet, ) -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import ( +from refiners.foundationals.latent_diffusion.stable_diffusion_xl import ( SDXLUNet, -) -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import ( DoubleTextEncoder, ) diff --git a/src/refiners/foundationals/latent_diffusion/self_attention_injection.py b/src/refiners/foundationals/latent_diffusion/self_attention_injection.py index 4a4498b..2f02b5b 100644 --- a/src/refiners/foundationals/latent_diffusion/self_attention_injection.py +++ b/src/refiners/foundationals/latent_diffusion/self_attention_injection.py @@ -10,7 +10,7 @@ from refiners.fluxion.layers import ( Parallel, ) from refiners.adapters.adapter import Adapter -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet +from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock from torch import Tensor diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 0da8243..f4f6d03 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -1,20 +1,22 @@ from dataclasses import dataclass from typing import Any, TypeVar, TypedDict, cast from pydantic import BaseModel -from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat from loguru import logger from torch.utils.data import Dataset -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL -from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from torchvision.transforms import RandomCrop # type: ignore import refiners.fluxion.layers as fl from PIL import Image from functools import cached_property from refiners.training_utils.config import BaseConfig -from refiners.foundationals.latent_diffusion import StableDiffusion_1 -from refiners.foundationals.latent_diffusion.schedulers import DPMSolver +from refiners.foundationals.latent_diffusion import ( + StableDiffusion_1, + DPMSolver, + SD1UNet, + LatentDiffusionAutoencoder, +) +from refiners.foundationals.latent_diffusion.schedulers import DDPM from torch.nn.functional import mse_loss import random from refiners.training_utils.wandb import WandbLoggable diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 4a2ed9d..efd9649 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -8,9 +8,12 @@ from PIL import Image from pathlib import Path from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed -from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet -from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet +from refiners.foundationals.latent_diffusion import ( + StableDiffusion_1, + StableDiffusion_1_Inpainting, + SD1UNet, + SD1Controlnet, +) from refiners.foundationals.latent_diffusion.lora import LoraWeights from refiners.foundationals.latent_diffusion.schedulers import DDIM from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection diff --git a/tests/foundationals/latent_diffusion/test_unet.py b/tests/foundationals/latent_diffusion/test_unet.py index a4af0f5..568afc5 100644 --- a/tests/foundationals/latent_diffusion/test_unet.py +++ b/tests/foundationals/latent_diffusion/test_unet.py @@ -1,4 +1,4 @@ -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet +from refiners.foundationals.latent_diffusion import SD1UNet from refiners.fluxion import manual_seed import torch