refactor fl.Parameter basic layer

Co-authored-by: Cédric Deltheil <cedric@deltheil.me>
This commit is contained in:
limiteinductive 2023-12-08 10:12:37 +01:00 committed by Benjamin Trom
parent 11422a3faf
commit 807ef5551c
8 changed files with 25 additions and 47 deletions

View file

@ -93,7 +93,7 @@ def main() -> None:
if fine_grained:
w = image_proj_weights
image_proj_state_dict = {
"LatentsEncoder.Parallel.Parameter.parameter": w["latents"].squeeze(0), # drop batch dim = 1
"LatentsToken.Parameter.weight": w["latents"].squeeze(0), # drop batch dim = 1
"Linear_1.weight": w["proj_in.weight"],
"Linear_1.bias": w["proj_in.bias"],
"Linear_2.weight": w["proj_out.weight"],

View file

@ -80,10 +80,10 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
mapping = converter.map_state_dicts(source_args=(x,))
assert mapping
mapping["PositionalEncoder.Parameter.parameter"] = "pos_embed"
mapping["PositionalEncoder.Parameter.weight"] = "pos_embed"
target_state_dict = refiners_sam_vit_h.state_dict()
del target_state_dict["PositionalEncoder.Parameter.parameter"]
del target_state_dict["PositionalEncoder.Parameter.weight"]
source_state_dict = vit.state_dict()
pos_embed = source_state_dict["pos_embed"]
@ -112,7 +112,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
converted_source["PositionalEncoder.Parameter.parameter"] = pos_embed # type: ignore
embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
converted_source.update(rel_items)
refiners_sam_vit_h.load_state_dict(state_dict=converted_source)

View file

@ -64,9 +64,7 @@ def setup_converter(args: Args) -> ModelConverter:
# Remove the class embedding from state dict since it was not mapped by the model converter
class_embedding = target.ensure_find(fl.Parameter)
class_embedding_key = next(
(n for n, p in target.named_parameters() if id(p) == id(class_embedding.parameter)), None
)
class_embedding_key = next((n for n, p in target.named_parameters() if id(p) == id(class_embedding.weight)), None)
assert class_embedding_key is not None
assert class_embedding_key in target_state_dict
del target_state_dict[class_embedding_key]
@ -77,7 +75,8 @@ def setup_converter(args: Args) -> ModelConverter:
target.load_state_dict(state_dict=converted_state_dict, strict=False)
# Ad hoc post-conversion steps
class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore
embed = source.vision_model.embeddings.class_embedding
class_embedding.weight = torch.nn.Parameter(embed.clone().reshape_as(class_embedding.weight)) # type: ignore
assert converter.compare_models((x,), threshold=args.threshold)

View file

@ -302,7 +302,7 @@ convert_unclip () {
--from "tests/weights/stabilityai/stable-diffusion-2-1-unclip" \
--to "tests/weights/CLIPImageEncoderH.safetensors" \
--half
check_hash "tests/weights/CLIPImageEncoderH.safetensors" 82918ff4
check_hash "tests/weights/CLIPImageEncoderH.safetensors" 4ddb44d2
}
convert_ip_adapter () {
@ -321,13 +321,13 @@ convert_ip_adapter () {
--from "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin" \
--to "tests/weights/ip-adapter-plus_sd15.safetensors" \
--half
check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 346a31d1
check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 842b20e2
python scripts/conversion/convert_diffusers_ip_adapter.py \
--from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" \
--to "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" \
--half
check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" d195feb3
check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" 0409974b
}
convert_t2i_adapter () {
@ -349,7 +349,7 @@ convert_sam () {
python scripts/conversion/convert_segment_anything.py \
--from "tests/weights/sam_vit_h_4b8939.pth" \
--to "tests/weights/segment-anything-h.safetensors"
check_hash "tests/weights/segment-anything-h.safetensors" 321d6f23
check_hash "tests/weights/segment-anything-h.safetensors" 6b843800
}
download_all () {

View file

@ -158,18 +158,10 @@ class Parameter(WeightedModule):
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.dims = dims
self.register_parameter("parameter", TorchParameter(randn(*dims, device=device, dtype=dtype)))
self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
@property
def device(self) -> Device:
return self.parameter.device
@property
def dtype(self) -> DType:
return self.parameter.dtype
def forward(self, _: Tensor) -> Tensor:
return self.parameter
def forward(self, x: Tensor) -> Tensor:
return self.weight.expand(x.shape[0], *self.dims)
class Buffer(WeightedModule):

View file

@ -3,18 +3,10 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
class ClassEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
class ClassToken(fl.Chain):
def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.embedding_dim = embedding_dim
super().__init__(
fl.Parallel(fl.Identity(), fl.Parameter(embedding_dim, device=device, dtype=dtype)),
fl.Lambda(lambda x, p: p.expand(x.shape[0], 1, -1)),
)
super().__init__(fl.Parameter(1, embedding_dim, device=device, dtype=dtype))
class PatchEncoder(fl.Chain):
@ -87,7 +79,7 @@ class ViTEmbeddings(fl.Chain):
self.patch_size = patch_size
super().__init__(
fl.Concatenate(
ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
ClassToken(embedding_dim, device=device, dtype=dtype),
fl.Chain(
PatchEncoder(
in_channels=3,

View file

@ -165,18 +165,13 @@ class PerceiverAttention(fl.Chain):
return cat((x, latents), dim=-2)
class LatentsEncoder(fl.Chain):
class LatentsToken(fl.Chain):
def __init__(
self,
num_tokens: int,
embeddding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
self, num_tokens: int, latents_dim: int, device: Device | str | None = None, dtype: DType | None = None
) -> None:
super().__init__(
fl.Parallel(fl.Identity(), fl.Parameter(num_tokens, embeddding_dim, device=device, dtype=dtype)),
fl.Lambda(lambda x, p: p.expand(x.shape[0], -1, -1)),
)
self.num_tokens = num_tokens
self.latents_dim = latents_dim
super().__init__(fl.Parameter(num_tokens, latents_dim, device=device, dtype=dtype))
class Transformer(fl.Chain):
@ -211,7 +206,7 @@ class PerceiverResampler(fl.Chain):
super().__init__(
fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
fl.SetContext(context="perceiver_resampler", key="x"),
LatentsEncoder(num_tokens=num_tokens, embeddding_dim=latents_dim, device=device, dtype=dtype),
LatentsToken(num_tokens, latents_dim, device=device, dtype=dtype),
Transformer(
TransformerLayer(
fl.Residual(

View file

@ -46,7 +46,6 @@ class PositionalEncoder(fl.Residual):
self.image_embedding_size = image_embedding_size
super().__init__(
fl.Parameter(
1,
image_embedding_size[0],
image_embedding_size[1],
embedding_dim,