From b933fabf3162f4b18735fb0810aacd20daac4368 Mon Sep 17 00:00:00 2001
From: Cédric Deltheil
Date: Thu, 31 Aug 2023 17:22:57 +0200
Subject: [PATCH] unet: get rid of clip_embedding attribute for SD1

It is implicitly defined by the underlying cross-attention layer. This
also makes it consistent with SDXL.
---
 scripts/conversion/convert_diffusers_controlnet.py     |  2 +-
 scripts/conversion/convert_diffusers_lora.py           |  2 +-
 scripts/conversion/convert_diffusers_unet.py           |  4 +---
 .../conversion/convert_refiners_lora_to_sdwebui.py     |  2 +-
 .../latent_diffusion/stable_diffusion_1/model.py       |  2 +-
 .../latent_diffusion/stable_diffusion_1/unet.py        | 11 ++---------
 src/refiners/training_utils/latent_diffusion.py        |  2 +-
 tests/e2e/test_diffusion.py                            |  4 ++--
 .../foundationals/latent_diffusion/test_controlnet.py  |  2 +-
 .../latent_diffusion/test_reference_only_control.py    |  2 +-
 tests/foundationals/latent_diffusion/test_unet.py      |  2 +-
 11 files changed, 13 insertions(+), 22 deletions(-)

diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py
index 50935ed..faba547 100644
--- a/scripts/conversion/convert_diffusers_controlnet.py
+++ b/scripts/conversion/convert_diffusers_controlnet.py
@@ -21,7 +21,7 @@ class Args(argparse.Namespace):
 @torch.no_grad()
 def convert(args: Args) -> dict[str, torch.Tensor]:
     controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path)  # type: ignore
-    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=4)
     adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
     controlnet = unet.Controlnet

diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py
index 3afde74..07d8fdc 100644
--- a/scripts/conversion/convert_diffusers_lora.py
+++ b/scripts/conversion/convert_diffusers_lora.py
@@ -44,7 +44,7 @@ def process(args: Args) -> None:
     diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model)  # type: ignore
     diffusers_model = cast(fl.Module, diffusers_sd.unet)  # type: ignore

-    refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    refiners_model = SD1UNet(in_channels=4)
     target = LoraTarget.CrossAttention
     metadata = {"unet_targets": "CrossAttentionBlock2d"}
     rank = diffusers_state_dict[
diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py
index fea8f2a..5748b6b 100644
--- a/scripts/conversion/convert_diffusers_unet.py
+++ b/scripts/conversion/convert_diffusers_unet.py
@@ -22,9 +22,7 @@ def setup_converter(args: Args) -> ModelConverter:
     source_clip_embedding_dim: int = source.config.cross_attention_dim  # type: ignore
     source_has_time_ids: bool = source.config.addition_embed_type == "text_time"  # type: ignore
     target = (
-        SDXLUNet(in_channels=source_in_channels)
-        if source_has_time_ids
-        else SD1UNet(in_channels=source_in_channels, clip_embedding_dim=source_clip_embedding_dim)
+        SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
     )

     x = torch.randn(1, source_in_channels, 32, 32)
diff --git a/scripts/conversion/convert_refiners_lora_to_sdwebui.py b/scripts/conversion/convert_refiners_lora_to_sdwebui.py
index 96fb413..839026c 100644
--- a/scripts/conversion/convert_refiners_lora_to_sdwebui.py
+++ b/scripts/conversion/convert_refiners_lora_to_sdwebui.py
@@ -62,7 +62,7 @@ def main() -> None:
     for meta_key, meta_value in metadata.items():
         match meta_key:
             case "unet_targets":
-                model = SD1UNet(in_channels=4, clip_embedding_dim=768)
+                model = SD1UNet(in_channels=4)
                 create_mapping = partial(get_unet_mapping, source_path=args.sd15)
                 key_prefix = "unet."
                 lora_prefix = "lora_unet_"
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
index ad851bc..21a1487 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -24,7 +24,7 @@ class StableDiffusion_1(LatentDiffusionModel):
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
-        unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
+        unet = unet or SD1UNet(in_channels=4)
         lda = lda or LatentDiffusionAutoencoder()
         clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
         scheduler = scheduler or DPMSolver(num_inference_steps=30)
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
index 894d839..2970bf2 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
@@ -243,17 +243,10 @@ class ResidualConcatenator(fl.Chain):


 class SD1UNet(fl.Chain):
-    structural_attrs = ["in_channels", "clip_embedding_dim"]
+    structural_attrs = ["in_channels"]

-    def __init__(
-        self,
-        in_channels: int,
-        clip_embedding_dim: int,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ):
+    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
-        self.clip_embedding_dim = clip_embedding_dim
         super().__init__(
             TimestepEncoder(device=device, dtype=dtype),
             DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index 2caca6f..bdd75be 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -128,7 +128,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
     @cached_property
     def unet(self) -> SD1UNet:
         assert self.config.models["unet"] is not None, "The config must contain a unet entry."
-        return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
+        return SD1UNet(in_channels=4, device=self.device).to(device=self.device)

     @cached_property
     def text_encoder(self) -> CLIPTextEncoderL:
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 411589e..c728af2 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -224,7 +224,7 @@ def sd15_inpainting(
         warn("not running on CPU, skipping")
         pytest.skip()

-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
@@ -242,7 +242,7 @@ def sd15_inpainting_float16(
         warn("not running on CPU, skipping")
         pytest.skip()

-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py
index 58f4050..b89b384 100644
--- a/tests/foundationals/latent_diffusion/test_controlnet.py
+++ b/tests/foundationals/latent_diffusion/test_controlnet.py
@@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet impor
 @pytest.fixture(scope="module", params=[True, False])
 def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
     with_parent: bool = request.param
-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     if with_parent:
         fl.Chain(unet)
     yield unet
diff --git a/tests/foundationals/latent_diffusion/test_reference_only_control.py b/tests/foundationals/latent_diffusion/test_reference_only_control.py
index dc0ab7b..580fdb0 100644
--- a/tests/foundationals/latent_diffusion/test_reference_only_control.py
+++ b/tests/foundationals/latent_diffusion/test_reference_only_control.py
@@ -14,7 +14,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti

 @torch.no_grad()
 def test_sai_inject_eject() -> None:
-    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9)
     sai = ReferenceOnlyControlAdapter(unet)

     nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
diff --git a/tests/foundationals/latent_diffusion/test_unet.py b/tests/foundationals/latent_diffusion/test_unet.py
index 568afc5..b56e0e1 100644
--- a/tests/foundationals/latent_diffusion/test_unet.py
+++ b/tests/foundationals/latent_diffusion/test_unet.py
@@ -9,7 +9,7 @@ def test_unet_context_flush():
     timestep = torch.randint(0, 999, size=(1, 1))
     x = torch.randn(1, 4, 32, 32)

-    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=4)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s

     with torch.no_grad():
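
Usage note: with this change the CLIP embedding dimension is no longer passed to the
SD1UNet constructor; it is defined implicitly by the underlying cross-attention layers,
which is what makes SD1 consistent with SDXL. A minimal before/after sketch follows; it
is hypothetical usage, assuming only the module path visible in the diff above, and is
not taken from the repository's own examples:

    # Sketch of the constructor change after this patch.
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet

    # Before this patch: unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
    unet = SD1UNet(in_channels=4)

    # The 9-channel (inpainting) variant used in the tests changes the same way:
    inpainting_unet = SD1UNet(in_channels=9)  # was: SD1UNet(in_channels=9, clip_embedding_dim=768)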