refactor fl.Parameter basic layer
Co-authored-by: Cédric Deltheil <cedric@deltheil.me>
parent 11422a3faf
commit 807ef5551c
@@ -93,7 +93,7 @@ def main() -> None:
     if fine_grained:
         w = image_proj_weights
         image_proj_state_dict = {
-            "LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0),  # drop batch dim = 1
+            "LatentsToken.Parameter.weight": w["latents"].squeeze(0),  # drop batch dim = 1
             "Linear_1.weight": w["proj_in.weight"],
             "Linear_1.bias": w["proj_in.bias"],
             "Linear_2.weight": w["proj_out.weight"],

@@ -80,10 +80,10 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
     mapping = converter.map_state_dicts(source_args=(x,))
     assert mapping

-    mapping["PositionalEncoder.Parameter.parameter"] = "pos_embed"
+    mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"

     target_state_dict = refiners_sam_vit_h.state_dict()
-    del target_state_dict["PositionalEncoder.Parameter.parameter"]
+    del target_state_dict["PositionalEncoder.Parameter.weight"]

     source_state_dict = vit.state_dict()
     pos_embed = source_state_dict["pos_embed"]

@@ -112,7 +112,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
         source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
     )

-    converted_source["PositionalEncoder.Parameter.parameter"] = pos_embed  # type: ignore
+    embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
+    converted_source["PositionalEncoder.Parameter.weight"] = embed  # type: ignore
     converted_source.update(rel_items)

     refiners_sam_vit_h.load_state_dict(state_dict=converted_source)

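Note (illustration, not part of the commit): the extra reshape_as step is needed because the source checkpoint stores pos_embed with a leading batch dimension of 1, while the refactored fl.Parameter keeps its weight without that dimension and re-adds it in forward(). A minimal sketch, with shapes assumed for SAM ViT-H (64x64 image embedding, dim 1280):

    # Hypothetical shapes; only the leading-dim mismatch matters here.
    import torch

    pos_embed = torch.randn(1, 64, 64, 1280)         # source tensor, leading batch dim of 1
    target_weight = torch.empty(64, 64, 1280)         # shape of the new Parameter.weight
    converted = pos_embed.reshape_as(target_weight)   # drop the leading dim before loading
    assert converted.shape == target_weight.shape
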
@@ -64,9 +64,7 @@ def setup_converter(args: Args) -> ModelConverter:

     # Remove the class embedding from state dict since it was not mapped by the model converter
     class_embedding = target.ensure_find(fl.Parameter)
-    class_embedding_key = next(
-        (n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
-    )
+    class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
     assert class_embedding_key is not None
     assert class_embedding_key in target_state_dict
     del target_state_dict[class_embedding_key]

@@ -77,7 +75,8 @@ def setup_converter(args: Args) -> ModelConverter:
     target.load_state_dict(state_dict=converted_state_dict, strict=False)

     # Ad hoc post-conversion steps
-    class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone())  # type: ignore
+    embed = source.vision_model.embeddings.class_embedding
+    class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight))  # type: ignore

     assert converter.compare_models((x,), threshold=args.threshold)

@@ -302,7 +302,7 @@ convert_unclip () {
         --from "tests/weights/stabilityai/stable-diffusion-2-1-unclip" \
         --to "tests/weights/CLIPImageEncoderH.safetensors" \
         --half
-    check_hash "tests/weights/CLIPImageEncoderH.safetensors" 82918ff4
+    check_hash "tests/weights/CLIPImageEncoderH.safetensors" 4ddb44d2
 }

 convert_ip_adapter () {

@@ -321,13 +321,13 @@ convert_ip_adapter () {
         --from "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin" \
         --to "tests/weights/ip-adapter-plus_sd15.safetensors" \
         --half
-    check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 346a31d1
+    check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 842b20e2

     python scripts/conversion/convert_diffusers_ip_adapter.py \
         --from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" \
         --to "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" \
         --half
-    check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" d195feb3
+    check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" 0409974b
 }

 convert_t2i_adapter () {

@@ -349,7 +349,7 @@ convert_sam () {
     python scripts/conversion/convert_segment_anything.py \
         --from "tests/weights/sam_vit_h_4b8939.pth" \
         --to "tests/weights/segment-anything-h.safetensors"
-    check_hash "tests/weights/segment-anything-h.safetensors" 321d6f23
+    check_hash "tests/weights/segment-anything-h.safetensors" 6b843800
 }

 download_all () {

@@ -158,18 +158,10 @@ class Parameter(WeightedModule):
     def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__()
         self.dims = dims
-        self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
-
-    @property
-    def device(self) -> Device:
-        return self.parameter.device
-
-    @property
-    def dtype(self) -> DType:
-        return self.parameter.dtype
-
-    def forward(self, _: Tensor) -> Tensor:
-        return self.parameter
+        self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.weight.expand(x.shape[0], *self.dims)


 class Buffer(WeightedModule):

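A minimal, self-contained sketch of the refactored layer's behavior, reduced to plain PyTorch (the WeightedModule/device/dtype plumbing of the real class is omitted): the learnable tensor is now exposed as weight, and forward broadcasts it along the batch dimension of its input.

    import torch
    from torch import Tensor, nn


    class Parameter(nn.Module):
        """Sketch of the refactored fl.Parameter, minus refiners-specific plumbing."""

        def __init__(self, *dims: int) -> None:
            super().__init__()
            self.dims = dims
            # Stored as a plain attribute named `weight`, so state-dict keys
            # become "...Parameter.weight" instead of "...Parameter.parameter".
            self.weight = nn.Parameter(torch.randn(*dims))

        def forward(self, x: Tensor) -> Tensor:
            # The input only contributes its batch size; the weight is broadcast over it.
            return self.weight.expand(x.shape[0], *self.dims)


    token = Parameter(1, 768)
    out = token(torch.zeros(4, 77, 768))  # any tensor with a batch dimension works
    assert out.shape == (4, 1, 768)
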
@@ -3,18 +3,10 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.common import PositionalEncoder, FeedForward


-class ClassEncoder(fl.Chain):
-    def __init__(
-        self,
-        embedding_dim: int,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
+class ClassToken(fl.Chain):
+    def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.embedding_dim = embedding_dim
-        super().__init__(
-            fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
-            fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
-        )
+        super().__init__(fl.Parameter(1, embedding_dim, device=device, dtype=dtype))


 class PatchEncoder(fl.Chain):

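A small sanity check (an illustration, not part of the commit) that the simplified ClassToken matches the old Parallel/Lambda chain: the old code expanded a flat (embedding_dim,) parameter to (batch, 1, embedding_dim), while the new code stores a (1, embedding_dim) weight and lets Parameter.forward add the batch dimension.

    import torch

    batch, embedding_dim = 4, 768
    weight = torch.randn(1, embedding_dim)              # new Parameter(1, embedding_dim).weight

    old_style = weight.squeeze(0).expand(batch, 1, -1)  # old Parallel + Lambda path
    new_style = weight.expand(batch, 1, embedding_dim)  # new Parameter.forward
    assert torch.equal(old_style, new_style)
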
@@ -87,7 +79,7 @@ class ViTEmbeddings(fl.Chain):
         self.patch_size = patch_size
         super().__init__(
             fl.Concatenate(
-                ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
+                ClassToken(embedding_dim, device=device, dtype=dtype),
                 fl.Chain(
                     PatchEncoder(
                         in_channels=3,

@@ -165,18 +165,13 @@ class PerceiverAttention(fl.Chain):
         return cat((x, latents), dim=-2)


-class LatentsEncoder(fl.Chain):
+class LatentsToken(fl.Chain):
     def __init__(
-        self,
-        num_tokens: int,
-        embeddding_dim: int,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
+        self, num_tokens: int, latents_dim: int, device: Device | str | None = None, dtype: DType | None = None
     ) -> None:
-        super().__init__(
-            fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)),
-            fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
-        )
+        self.num_tokens = num_tokens
+        self.latents_dim = latents_dim
+        super().__init__(fl.Parameter(num_tokens, latents_dim, device=device, dtype=dtype))


 class Transformer(fl.Chain):

@@ -211,7 +206,7 @@ class PerceiverResampler(fl.Chain):
         super().__init__(
             fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
             fl.SetContext(context="perceiver_resampler", key="x"),
-            LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
+            LatentsToken(num_tokens, latents_dim, device=device, dtype=dtype),
             Transformer(
                 TransformerLayer(
                     fl.Residual(

@@ -46,7 +46,6 @@ class PositionalEncoder(fl.Residual):
         self.image_embedding_size = image_embedding_size
         super().__init__(
             fl.Parameter(
-                1,
                 image_embedding_size[0],
                 image_embedding_size[1],
                 embedding_dim,