Fix module registration in IP-Adapter

This commit is contained in:
Doryan Kaced 2023-09-22 17:27:23 +02:00
parent 72854de669
commit 251277a0a8
2 changed files with 28 additions and 13 deletions

View file

@ -162,6 +162,10 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
# Held in single-element lists to prevent PyTorch module registration
_clip_image_encoder: list[CLIPImageEncoderH]
_image_proj: list[ImageProjection]
def __init__(
self,
target: T,
@ -174,13 +178,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
self.image_proj = ImageProjection(
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)]
self._image_proj = [
ImageProjection(
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
]
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale)
@ -203,6 +209,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
    """Return the CLIP image encoder stored in ``_clip_image_encoder``.

    The encoder is kept as the first element of a list instead of being
    assigned directly as an attribute, which keeps PyTorch from registering
    it as a submodule of this adapter.
    """
    encoder = self._clip_image_encoder[0]
    return encoder
@property
def image_proj(self) -> ImageProjection:
    """Return the image projection stored in ``_image_proj``.

    The projection is kept as the first element of a list instead of being
    assigned directly as an attribute, which keeps PyTorch from registering
    it as a submodule of this adapter.
    """
    projection = self._image_proj[0]
    return projection
def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
for adapter in self.sub_adapters:
adapter.inject()

View file

@ -1125,6 +1125,13 @@ def test_diffusion_ip_adapter_controlnet(
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
depth_controlnet = SD1ControlnetAdapter(
sd15.unet,
name="depth",
scale=1.0,
weights=load_from_safetensors(depth_cn_weights_path),
).inject()
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
@ -1138,12 +1145,6 @@ def test_diffusion_ip_adapter_controlnet(
)
)
depth_controlnet = SD1ControlnetAdapter(
sd15.unet,
name="depth",
scale=1.0,
weights=load_from_safetensors(depth_cn_weights_path),
).inject()
depth_cn_condition = image_to_tensor(
depth_condition_image.convert("RGB"),
device=test_device,