From 251277a0a898027479964272f047c4b6db93fbf6 Mon Sep 17 00:00:00 2001 From: Doryan Kaced Date: Fri, 22 Sep 2023 17:27:23 +0200 Subject: [PATCH] Fix module registration in IP-Adapter --- .../latent_diffusion/image_prompt.py | 28 ++++++++++++++----- tests/e2e/test_diffusion.py | 13 +++++---- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 9dc358d..1734729 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -162,6 +162,10 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): class IPAdapter(Generic[T], fl.Chain, Adapter[T]): + # Prevent PyTorch module registration + _clip_image_encoder: list[CLIPImageEncoderH] + _image_proj: list[ImageProjection] + def __init__( self, target: T, @@ -174,13 +178,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): cross_attn_2d = target.ensure_find(CrossAttentionBlock2d) - self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype) - self.image_proj = ImageProjection( - clip_image_embedding_dim=self.clip_image_encoder.output_dim, - clip_text_embedding_dim=cross_attn_2d.context_embedding_dim, - device=target.device, - dtype=target.dtype, - ) + self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)] + self._image_proj = [ + ImageProjection( + clip_image_embedding_dim=self.clip_image_encoder.output_dim, + clip_text_embedding_dim=cross_attn_2d.context_embedding_dim, + device=target.device, + dtype=target.dtype, + ) + ] self.sub_adapters = [ CrossAttentionAdapter(target=cross_attn, scale=scale) @@ -203,6 +209,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): cross_attn.load_state_dict(state_dict=cross_attn_state_dict) + @property + def clip_image_encoder(self) -> CLIPImageEncoderH: + return self._clip_image_encoder[0] + + @property + def image_proj(self) -> ImageProjection: + return self._image_proj[0] + def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter": for adapter in self.sub_adapters: adapter.inject() diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 11ff09f..5cfd424 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -1125,6 +1125,13 @@ def test_diffusion_ip_adapter_controlnet( ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() + depth_controlnet = SD1ControlnetAdapter( + sd15.unet, + name="depth", + scale=1.0, + weights=load_from_safetensors(depth_cn_weights_path), + ).inject() + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image)) @@ -1138,12 +1145,6 @@ def test_diffusion_ip_adapter_controlnet( ) ) - depth_controlnet = SD1ControlnetAdapter( - sd15.unet, - name="depth", - scale=1.0, - weights=load_from_safetensors(depth_cn_weights_path), - ).inject() depth_cn_condition = image_to_tensor( depth_condition_image.convert("RGB"), device=test_device,