diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py
index 2bab229..0d48660 100644
--- a/scripts/conversion/convert_diffusers_ip_adapter.py
+++ b/scripts/conversion/convert_diffusers_ip_adapter.py
@@ -93,7 +93,7 @@ def main() -> None:
     if fine_grained:
         w = image_proj_weights
         image_proj_state_dict = {
-            "LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0),  # drop batch dim = 1
+            "LatentsToken.Parameter.weight": w["latents"].squeeze(0),  # drop batch dim = 1
             "Linear_1.weight": w["proj_in.weight"],
             "Linear_1.bias": w["proj_in.bias"],
             "Linear_2.weight": w["proj_out.weight"],
diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py
index 64ab9f1..b4a0a3f 100644
--- a/scripts/conversion/convert_segment_anything.py
+++ b/scripts/conversion/convert_segment_anything.py
@@ -80,10 +80,10 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
     mapping = converter.map_state_dicts(source_args=(x,))
     assert mapping
 
-    mapping["PositionalEncoder.Parameter.parameter"] = "pos_embed"
+    mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"
 
     target_state_dict = refiners_sam_vit_h.state_dict()
-    del target_state_dict["PositionalEncoder.Parameter.parameter"]
+    del target_state_dict["PositionalEncoder.Parameter.weight"]
 
     source_state_dict = vit.state_dict()
     pos_embed = source_state_dict["pos_embed"]
@@ -112,7 +112,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
         source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
     )
 
-    converted_source["PositionalEncoder.Parameter.parameter"] = pos_embed  # type: ignore
+    embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
+    converted_source["PositionalEncoder.Parameter.weight"] = embed  # type: ignore
     converted_source.update(rel_items)
 
     refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
diff --git a/scripts/conversion/convert_transformers_clip_image_model.py b/scripts/conversion/convert_transformers_clip_image_model.py
index 53e896d..9ae73d8 100644
--- a/scripts/conversion/convert_transformers_clip_image_model.py
+++ b/scripts/conversion/convert_transformers_clip_image_model.py
@@ -64,9 +64,7 @@ def setup_converter(args: Args) -> ModelConverter:
 
     # Remove the class embedding from state dict since it was not mapped by the model converter
     class_embedding = target.ensure_find(fl.Parameter)
-    class_embedding_key = next(
-        (n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
-    )
+    class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
     assert class_embedding_key is not None
     assert class_embedding_key in target_state_dict
     del target_state_dict[class_embedding_key]
@@ -77,7 +75,8 @@ def setup_converter(args: Args) -> ModelConverter:
     target.load_state_dict(state_dict=converted_state_dict, strict=False)
 
     # Ad hoc post-conversion steps
-    class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone())  # type: ignore
+    embed = source.vision_model.embeddings.class_embedding
+    class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight))  # type: ignore
 
     assert converter.compare_models((x,), threshold=args.threshold)
 
diff --git a/scripts/prepare-test-weights.sh b/scripts/prepare-test-weights.sh
index c4335ec..48e13da 100755
--- a/scripts/prepare-test-weights.sh
+++ b/scripts/prepare-test-weights.sh
@@ -302,7 +302,7 @@ convert_unclip () {
         --from "tests/weights/stabilityai/stable-diffusion-2-1-unclip" \
         --to "tests/weights/CLIPImageEncoderH.safetensors" \
         --half
-    check_hash "tests/weights/CLIPImageEncoderH.safetensors" 82918ff4
+    check_hash "tests/weights/CLIPImageEncoderH.safetensors" 4ddb44d2
 }
 
 convert_ip_adapter () {
@@ -321,13 +321,13 @@ convert_ip_adapter () {
         --from "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin" \
         --to "tests/weights/ip-adapter-plus_sd15.safetensors" \
         --half
-    check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 346a31d1
+    check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 842b20e2
 
     python scripts/conversion/convert_diffusers_ip_adapter.py \
         --from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" \
         --to "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" \
         --half
-    check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" d195feb3
+    check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" 0409974b
 }
 
 convert_t2i_adapter () {
@@ -349,7 +349,7 @@ convert_sam () {
     python scripts/conversion/convert_segment_anything.py \
        --from "tests/weights/sam_vit_h_4b8939.pth" \
        --to "tests/weights/segment-anything-h.safetensors"
-    check_hash "tests/weights/segment-anything-h.safetensors" 321d6f23
+    check_hash "tests/weights/segment-anything-h.safetensors" 6b843800
 }
 
 download_all () {
diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py
index 9e6ff1f..cad3355 100644
--- a/src/refiners/fluxion/layers/basics.py
+++ b/src/refiners/fluxion/layers/basics.py
@@ -158,18 +158,10 @@ class Parameter(WeightedModule):
     def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__()
         self.dims = dims
-        self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
+        self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
 
-    @property
-    def device(self) -> Device:
-        return self.parameter.device
-
-    @property
-    def dtype(self) -> DType:
-        return self.parameter.dtype
-
-    def forward(self, _: Tensor) -> Tensor:
-        return self.parameter
+    def forward(self, x: Tensor) -> Tensor:
+        return self.weight.expand(x.shape[0], *self.dims)
 
 
 class Buffer(WeightedModule):
diff --git a/src/refiners/foundationals/clip/image_encoder.py b/src/refiners/foundationals/clip/image_encoder.py
index 910cd31..b969ed1 100644
--- a/src/refiners/foundationals/clip/image_encoder.py
+++ b/src/refiners/foundationals/clip/image_encoder.py
@@ -3,18 +3,10 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
 
 
-class ClassEncoder(fl.Chain):
-    def __init__(
-        self,
-        embedding_dim: int,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
+class ClassToken(fl.Chain):
+    def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.embedding_dim = embedding_dim
-        super().__init__(
-            fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
-            fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
-        )
+        super().__init__(fl.Parameter(1, embedding_dim, device=device, dtype=dtype))
 
 
 class PatchEncoder(fl.Chain):
@@ -87,7 +79,7 @@ class ViTEmbeddings(fl.Chain):
         self.patch_size = patch_size
         super().__init__(
             fl.Concatenate(
-                ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
+                ClassToken(embedding_dim, device=device, dtype=dtype),
                 fl.Chain(
                     PatchEncoder(
                         in_channels=3,
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 50cab33..72e1247 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -165,18 +165,13 @@ class PerceiverAttention(fl.Chain):
         return cat((x, latents), dim=-2)
 
 
-class LatentsEncoder(fl.Chain):
+class LatentsToken(fl.Chain):
     def __init__(
-        self,
-        num_tokens: int,
-        embeddding_dim: int,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
+        self, num_tokens: int, latents_dim: int, device: Device | str | None = None, dtype: DType | None = None
     ) -> None:
-        super().__init__(
-            fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)),
-            fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
-        )
+        self.num_tokens = num_tokens
+        self.latents_dim = latents_dim
+        super().__init__(fl.Parameter(num_tokens, latents_dim, device=device, dtype=dtype))
 
 
 class Transformer(fl.Chain):
@@ -211,7 +206,7 @@ class PerceiverResampler(fl.Chain):
         super().__init__(
             fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
             fl.SetContext(context="perceiver_resampler", key="x"),
-            LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
+            LatentsToken(num_tokens, latents_dim, device=device, dtype=dtype),
             Transformer(
                 TransformerLayer(
                     fl.Residual(
diff --git a/src/refiners/foundationals/segment_anything/image_encoder.py b/src/refiners/foundationals/segment_anything/image_encoder.py
index 4a4f7e7..1e02e33 100644
--- a/src/refiners/foundationals/segment_anything/image_encoder.py
+++ b/src/refiners/foundationals/segment_anything/image_encoder.py
@@ -46,7 +46,6 @@ class PositionalEncoder(fl.Residual):
         self.image_embedding_size = image_embedding_size
         super().__init__(
             fl.Parameter(
-                1,
                 image_embedding_size[0],
                 image_embedding_size[1],
                 embedding_dim,
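
The following standalone sketch (not part of the patch) illustrates the batched semantics adopted by the refactored fl.Parameter in src/refiners/fluxion/layers/basics.py above: the learnable tensor is stored as `weight` and expanded over the incoming batch dimension in `forward`, which is why ClassToken and LatentsToken no longer need the Parallel(Identity, Parameter) + Lambda(expand) pattern. The `ToyParameter` class name and the example shapes are hypothetical; only the expand logic mirrors the diff.

import torch
from torch import Tensor, nn


class ToyParameter(nn.Module):
    # Hypothetical stand-in for fl.Parameter: stores a learnable tensor as `weight`
    # and broadcasts it over the batch dimension of the input, as in the hunk above.
    def __init__(self, *dims: int) -> None:
        super().__init__()
        self.dims = dims
        self.weight = nn.Parameter(torch.randn(*dims))

    def forward(self, x: Tensor) -> Tensor:
        # Expand the stored weight to the incoming batch size; the input values are
        # only used for their batch dimension, mirroring the new fl.Parameter.forward.
        return self.weight.expand(x.shape[0], *self.dims)


# Usage sketch: a (1, 768) class token broadcast to a batch of 4 -> shape (4, 1, 768).
token = ToyParameter(1, 768)
assert token(torch.empty(4, 3)).shape == (4, 1, 768)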