Fix module registration in IP-Adapter
commit 251277a0a8 (parent 72854de669)
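Assigning an nn.Module to an attribute of another nn.Module triggers PyTorch's module registration: the child lands in _modules, its parameters show up in parameters() and state_dict(), and it gets traversed by .to(). The IP-Adapter only wants to reference the CLIP image encoder and the image projection, not adopt their weights into its own chain, so this commit hides both behind plain Python lists and re-exposes them through properties. A minimal sketch of the underlying PyTorch behavior (toy classes for illustration, not the refiners API):

from torch import nn

class Registered(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # nn.Module.__setattr__ intercepts this assignment and records the
        # child in self._modules, so its weights appear in state_dict().
        self.encoder = nn.Linear(4, 4)

class Hidden(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A plain list is not an nn.Module, so __setattr__ stores it as an
        # ordinary attribute and the wrapped module is never registered.
        self._encoder = [nn.Linear(4, 4)]

    @property
    def encoder(self) -> nn.Linear:
        # Same public attribute access as before, without registration.
        return self._encoder[0]

print(list(Registered().state_dict()))  # ['encoder.weight', 'encoder.bias']
print(list(Hidden().state_dict()))      # []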
@@ -162,6 +162,10 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
 
 
 class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
+    # Prevent PyTorch module registration
+    _clip_image_encoder: list[CLIPImageEncoderH]
+    _image_proj: list[ImageProjection]
+
     def __init__(
         self,
         target: T,
@@ -174,13 +178,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
 
         cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
 
-        self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
-        self.image_proj = ImageProjection(
-            clip_image_embedding_dim=self.clip_image_encoder.output_dim,
-            clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
-            device=target.device,
-            dtype=target.dtype,
-        )
+        self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)]
+        self._image_proj = [
+            ImageProjection(
+                clip_image_embedding_dim=self.clip_image_encoder.output_dim,
+                clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
+                device=target.device,
+                dtype=target.dtype,
+            )
+        ]
 
         self.sub_adapters = [
             CrossAttentionAdapter(target=cross_attn, scale=scale)
@@ -203,6 +209,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
 
             cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
 
+    @property
+    def clip_image_encoder(self) -> CLIPImageEncoderH:
+        return self._clip_image_encoder[0]
+
+    @property
+    def image_proj(self) -> ImageProjection:
+        return self._image_proj[0]
+
     def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
         for adapter in self.sub_adapters:
             adapter.inject()
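One caveat of the list trick (generic PyTorch behavior, not refiners-specific): a module hidden this way is also invisible to Module.to(), parameters(), and checkpoint traversal, so it has to be placed on the right device and loaded explicitly. That is consistent with the constructor above passing device=target.device, dtype=target.dtype into CLIPImageEncoderH, and with the test below loading encoder weights directly via ip_adapter.clip_image_encoder.load_from_safetensors(...). A quick illustration:

import torch
from torch import nn

class Hidden(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._encoder = [nn.Linear(4, 4)]

h = Hidden().to(dtype=torch.float16)
# .to() only walks registered submodules, so the hidden encoder is
# untouched and still holds float32 weights.
print(h._encoder[0].weight.dtype)  # torch.float32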
@@ -1125,6 +1125,13 @@ def test_diffusion_ip_adapter_controlnet(
     ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
     ip_adapter.inject()
 
+    depth_controlnet = SD1ControlnetAdapter(
+        sd15.unet,
+        name="depth",
+        scale=1.0,
+        weights=load_from_safetensors(depth_cn_weights_path),
+    ).inject()
+
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
 
@@ -1138,12 +1145,6 @@ def test_diffusion_ip_adapter_controlnet(
         )
     )
 
-    depth_controlnet = SD1ControlnetAdapter(
-        sd15.unet,
-        name="depth",
-        scale=1.0,
-        weights=load_from_safetensors(depth_cn_weights_path),
-    ).inject()
     depth_cn_condition = image_to_tensor(
         depth_condition_image.convert("RGB"),
         device=test_device,