mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Fix module registration in IP-Adapter
This commit is contained in:
parent
72854de669
commit
251277a0a8
|
@ -162,6 +162,10 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
|
||||||
|
|
||||||
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
# Prevent PyTorch module registration
|
||||||
|
_clip_image_encoder: list[CLIPImageEncoderH]
|
||||||
|
_image_proj: list[ImageProjection]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: T,
|
target: T,
|
||||||
|
@ -174,13 +178,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
|
||||||
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
|
||||||
|
|
||||||
self.clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
self._clip_image_encoder = [clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)]
|
||||||
self.image_proj = ImageProjection(
|
self._image_proj = [
|
||||||
|
ImageProjection(
|
||||||
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
|
clip_image_embedding_dim=self.clip_image_encoder.output_dim,
|
||||||
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
|
||||||
device=target.device,
|
device=target.device,
|
||||||
dtype=target.dtype,
|
dtype=target.dtype,
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
|
||||||
self.sub_adapters = [
|
self.sub_adapters = [
|
||||||
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
CrossAttentionAdapter(target=cross_attn, scale=scale)
|
||||||
|
@ -203,6 +209,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
|
||||||
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
|
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def clip_image_encoder(self) -> CLIPImageEncoderH:
|
||||||
|
return self._clip_image_encoder[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_proj(self) -> ImageProjection:
|
||||||
|
return self._image_proj[0]
|
||||||
|
|
||||||
def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
|
def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
|
||||||
for adapter in self.sub_adapters:
|
for adapter in self.sub_adapters:
|
||||||
adapter.inject()
|
adapter.inject()
|
||||||
|
|
|
@ -1125,6 +1125,13 @@ def test_diffusion_ip_adapter_controlnet(
|
||||||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||||
ip_adapter.inject()
|
ip_adapter.inject()
|
||||||
|
|
||||||
|
depth_controlnet = SD1ControlnetAdapter(
|
||||||
|
sd15.unet,
|
||||||
|
name="depth",
|
||||||
|
scale=1.0,
|
||||||
|
weights=load_from_safetensors(depth_cn_weights_path),
|
||||||
|
).inject()
|
||||||
|
|
||||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
|
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))
|
||||||
|
|
||||||
|
@ -1138,12 +1145,6 @@ def test_diffusion_ip_adapter_controlnet(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
depth_controlnet = SD1ControlnetAdapter(
|
|
||||||
sd15.unet,
|
|
||||||
name="depth",
|
|
||||||
scale=1.0,
|
|
||||||
weights=load_from_safetensors(depth_cn_weights_path),
|
|
||||||
).inject()
|
|
||||||
depth_cn_condition = image_to_tensor(
|
depth_cn_condition = image_to_tensor(
|
||||||
depth_condition_image.convert("RGB"),
|
depth_condition_image.convert("RGB"),
|
||||||
device=test_device,
|
device=test_device,
|
||||||
|
|
Loading…
Reference in a new issue