mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
refactor fl.Parameter basic layer
Co-authored-by: Cédric Deltheil <cedric@deltheil.me>
This commit is contained in:
parent
11422a3faf
commit
807ef5551c
|
@ -93,7 +93,7 @@ def main() -> None:
|
|||
if fine_grained:
|
||||
w = image_proj_weights
|
||||
image_proj_state_dict = {
|
||||
"LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0), # drop batch dim = 1
|
||||
"LatentsToken.Parameter.weight": w["latents"].squeeze(0), # drop batch dim = 1
|
||||
"Linear_1.weight": w["proj_in.weight"],
|
||||
"Linear_1.bias": w["proj_in.bias"],
|
||||
"Linear_2.weight": w["proj_out.weight"],
|
||||
|
|
|
@ -80,10 +80,10 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
|
|||
mapping = converter.map_state_dicts(source_args=(x,))
|
||||
assert mapping
|
||||
|
||||
mapping["PositionalEncoder.Parameter.parameter"] = "pos_embed"
|
||||
mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"
|
||||
|
||||
target_state_dict = refiners_sam_vit_h.state_dict()
|
||||
del target_state_dict["PositionalEncoder.Parameter.parameter"]
|
||||
del target_state_dict["PositionalEncoder.Parameter.weight"]
|
||||
|
||||
source_state_dict = vit.state_dict()
|
||||
pos_embed = source_state_dict["pos_embed"]
|
||||
|
@ -112,7 +112,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
|
|||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||
)
|
||||
|
||||
converted_source["PositionalEncoder.Parameter.parameter"] = pos_embed # type: ignore
|
||||
embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
|
||||
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
|
||||
converted_source.update(rel_items)
|
||||
|
||||
refiners_sam_vit_h.load_state_dict(state_dict=converted_source)
|
||||
|
|
|
@ -64,9 +64,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
|
||||
# Remove the class embedding from state dict since it was not mapped by the model converter
|
||||
class_embedding = target.ensure_find(fl.Parameter)
|
||||
class_embedding_key = next(
|
||||
(n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
|
||||
)
|
||||
class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
|
||||
assert class_embedding_key is not None
|
||||
assert class_embedding_key in target_state_dict
|
||||
del target_state_dict[class_embedding_key]
|
||||
|
@ -77,7 +75,8 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
target.load_state_dict(state_dict=converted_state_dict, strict=False)
|
||||
|
||||
# Ad hoc post-conversion steps
|
||||
class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore
|
||||
embed = source.vision_model.embeddings.class_embedding
|
||||
class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight)) # type: ignore
|
||||
|
||||
assert converter.compare_models((x,), threshold=args.threshold)
|
||||
|
||||
|
|
|
@ -302,7 +302,7 @@ convert_unclip () {
|
|||
--from "tests/weights/stabilityai/stable-diffusion-2-1-unclip" \
|
||||
--to "tests/weights/CLIPImageEncoderH.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/CLIPImageEncoderH.safetensors" 82918ff4
|
||||
check_hash "tests/weights/CLIPImageEncoderH.safetensors" 4ddb44d2
|
||||
}
|
||||
|
||||
convert_ip_adapter () {
|
||||
|
@ -321,13 +321,13 @@ convert_ip_adapter () {
|
|||
--from "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin" \
|
||||
--to "tests/weights/ip-adapter-plus_sd15.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 346a31d1
|
||||
check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 842b20e2
|
||||
|
||||
python scripts/conversion/convert_diffusers_ip_adapter.py \
|
||||
--from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" \
|
||||
--to "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" d195feb3
|
||||
check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" 0409974b
|
||||
}
|
||||
|
||||
convert_t2i_adapter () {
|
||||
|
@ -349,7 +349,7 @@ convert_sam () {
|
|||
python scripts/conversion/convert_segment_anything.py \
|
||||
--from "tests/weights/sam_vit_h_4b8939.pth" \
|
||||
--to "tests/weights/segment-anything-h.safetensors"
|
||||
check_hash "tests/weights/segment-anything-h.safetensors" 321d6f23
|
||||
check_hash "tests/weights/segment-anything-h.safetensors" 6b843800
|
||||
}
|
||||
|
||||
download_all () {
|
||||
|
|
|
@ -158,18 +158,10 @@ class Parameter(WeightedModule):
|
|||
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
|
||||
self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
|
||||
|
||||
@property
|
||||
def device(self) -> Device:
|
||||
return self.parameter.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
return self.parameter.dtype
|
||||
|
||||
def forward(self, _: Tensor) -> Tensor:
|
||||
return self.parameter
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.weight.expand(x.shape[0], *self.dims)
|
||||
|
||||
|
||||
class Buffer(WeightedModule):
|
||||
|
|
|
@ -3,18 +3,10 @@ import refiners.fluxion.layers as fl
|
|||
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
|
||||
|
||||
|
||||
class ClassEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
class ClassToken(fl.Chain):
|
||||
def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
self.embedding_dim = embedding_dim
|
||||
super().__init__(
|
||||
fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
|
||||
fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
|
||||
)
|
||||
super().__init__(fl.Parameter(1, embedding_dim, device=device, dtype=dtype))
|
||||
|
||||
|
||||
class PatchEncoder(fl.Chain):
|
||||
|
@ -87,7 +79,7 @@ class ViTEmbeddings(fl.Chain):
|
|||
self.patch_size = patch_size
|
||||
super().__init__(
|
||||
fl.Concatenate(
|
||||
ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
|
||||
ClassToken(embedding_dim, device=device, dtype=dtype),
|
||||
fl.Chain(
|
||||
PatchEncoder(
|
||||
in_channels=3,
|
||||
|
|
|
@ -165,18 +165,13 @@ class PerceiverAttention(fl.Chain):
|
|||
return cat((x, latents), dim=-2)
|
||||
|
||||
|
||||
class LatentsEncoder(fl.Chain):
|
||||
class LatentsToken(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
num_tokens: int,
|
||||
embeddding_dim: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
self, num_tokens: int, latents_dim: int, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)),
|
||||
fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
|
||||
)
|
||||
self.num_tokens = num_tokens
|
||||
self.latents_dim = latents_dim
|
||||
super().__init__(fl.Parameter(num_tokens, latents_dim, device=device, dtype=dtype))
|
||||
|
||||
|
||||
class Transformer(fl.Chain):
|
||||
|
@ -211,7 +206,7 @@ class PerceiverResampler(fl.Chain):
|
|||
super().__init__(
|
||||
fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
|
||||
fl.SetContext(context="perceiver_resampler", key="x"),
|
||||
LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
|
||||
LatentsToken(num_tokens, latents_dim, device=device, dtype=dtype),
|
||||
Transformer(
|
||||
TransformerLayer(
|
||||
fl.Residual(
|
||||
|
|
|
@ -46,7 +46,6 @@ class PositionalEncoder(fl.Residual):
|
|||
self.image_embedding_size = image_embedding_size
|
||||
super().__init__(
|
||||
fl.Parameter(
|
||||
1,
|
||||
image_embedding_size[0],
|
||||
image_embedding_size[1],
|
||||
embedding_dim,
|
||||
|
|
Loading…
Reference in a new issue