From ea44262a3981acb832f682aacf659cefdae35cb7 Mon Sep 17 00:00:00 2001 From: Benjamin Trom Date: Thu, 19 Oct 2023 10:17:25 +0200 Subject: [PATCH] unnest Residual subchain by modifying its forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also replaced the remaining Sum-Identity layers with Residual. The tolerance used to compare SAM's ViT models has been tweaked: for some reason there is a small difference (in float32) in the neck layer (first Conv2d). Co-authored-by: Cédric Deltheil --- .../convert_diffusers_controlnet.py | 28 +++--- .../convert_diffusers_ip_adapter.py | 22 ++--- .../conversion/convert_segment_anything.py | 14 +-- scripts/prepare-test-weights.sh | 50 +++++++---- src/refiners/fluxion/layers/chain.py | 10 ++- .../latent_diffusion/auto_encoder.py | 19 ++-- .../latent_diffusion/cross_attention.py | 88 ++++++++----- .../stable_diffusion_1/controlnet.py | 11 +-- .../segment_anything/test_sam.py | 2 +- 9 files changed, 119 insertions(+), 125 deletions(-) diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index 92d9b0a..11493cb 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -65,11 +65,11 @@ def convert(args: Args) -> dict[str, torch.Tensor]: expected_target_order = [ "DownBlocks.Chain_1.Passthrough.Conv2d", - "DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "DownBlocks.Chain_2.CLIPLCrossAttention.Chain_1.Conv2d", + "DownBlocks.Chain_2.CLIPLCrossAttention.Chain_3.Conv2d", "DownBlocks.Chain_2.Passthrough.Conv2d", - "DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "DownBlocks.Chain_3.CLIPLCrossAttention.Chain_1.Conv2d", + "DownBlocks.Chain_3.CLIPLCrossAttention.Chain_3.Conv2d", "DownBlocks.Chain_3.Passthrough.Conv2d", "DownBlocks.Chain_4.Passthrough.Conv2d", ] @@ -102,11 +102,11 @@ def convert(args: Args) -> dict[str, torch.Tensor]: ] expected_target_order = [ - "DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "DownBlocks.Chain_5.CLIPLCrossAttention.Chain_1.Conv2d", + "DownBlocks.Chain_5.CLIPLCrossAttention.Chain_3.Conv2d", "DownBlocks.Chain_5.Passthrough.Conv2d", - "DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "DownBlocks.Chain_6.CLIPLCrossAttention.Chain_1.Conv2d", + "DownBlocks.Chain_6.CLIPLCrossAttention.Chain_3.Conv2d", "DownBlocks.Chain_6.Passthrough.Conv2d", "DownBlocks.Chain_7.Passthrough.Conv2d", ] @@ -143,17 +143,17 @@ def convert(args: Args) -> dict[str, torch.Tensor]: ] expected_target_order = [ - "DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "DownBlocks.Chain_8.CLIPLCrossAttention.Chain_1.Conv2d", + "DownBlocks.Chain_8.CLIPLCrossAttention.Chain_3.Conv2d", "DownBlocks.Chain_8.Passthrough.Conv2d", - "DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "DownBlocks.Chain_9.CLIPLCrossAttention.Chain_1.Conv2d", + "DownBlocks.Chain_9.CLIPLCrossAttention.Chain_3.Conv2d", "DownBlocks.Chain_9.Passthrough.Conv2d", "DownBlocks.Chain_10.Passthrough.Conv2d", "DownBlocks.Chain_11.Passthrough.Conv2d",
"DownBlocks.Chain_12.Passthrough.Conv2d", - "MiddleBlock.CLIPLCrossAttention.Chain.Chain_1.Conv2d", - "MiddleBlock.CLIPLCrossAttention.Chain.Chain_3.Conv2d", + "MiddleBlock.CLIPLCrossAttention.Chain_1.Conv2d", + "MiddleBlock.CLIPLCrossAttention.Chain_3.Conv2d", "MiddleBlock.Passthrough.Conv2d", ] diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index f2a6b84..2bab229 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -105,17 +105,17 @@ def main() -> None: t_pfx, s_pfx = f"Transformer.TransformerLayer_{i+1}.Residual_", f"layers.{i}." image_proj_state_dict.update( { - f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"], - f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"], - f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"], - f"{t_pfx}1.Chain.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"], - f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"], - f"{t_pfx}1.Chain.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"], - f"{t_pfx}1.Chain.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"], - f"{t_pfx}2.Chain.LayerNorm.weight": w[f"{s_pfx}1.0.weight"], - f"{t_pfx}2.Chain.LayerNorm.bias": w[f"{s_pfx}1.0.bias"], - f"{t_pfx}2.Chain.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"], - f"{t_pfx}2.Chain.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"], + f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.weight": w[f"{s_pfx}0.norm1.weight"], + f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_1.bias": w[f"{s_pfx}0.norm1.bias"], + f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.weight": w[f"{s_pfx}0.norm2.weight"], + f"{t_pfx}1.PerceiverAttention.Distribute.LayerNorm_2.bias": w[f"{s_pfx}0.norm2.bias"], + f"{t_pfx}1.PerceiverAttention.Parallel.Chain_2.Linear.weight": w[f"{s_pfx}0.to_q.weight"], + f"{t_pfx}1.PerceiverAttention.Parallel.Chain_1.Linear.weight": w[f"{s_pfx}0.to_kv.weight"], + f"{t_pfx}1.PerceiverAttention.Linear.weight": w[f"{s_pfx}0.to_out.weight"], + f"{t_pfx}2.LayerNorm.weight": w[f"{s_pfx}1.0.weight"], + f"{t_pfx}2.LayerNorm.bias": w[f"{s_pfx}1.0.bias"], + f"{t_pfx}2.FeedForward.Linear_1.weight": w[f"{s_pfx}1.1.weight"], + f"{t_pfx}2.FeedForward.Linear_2.weight": w[f"{s_pfx}1.3.weight"], } ) else: diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index fdfd9bd..64ab9f1 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -55,7 +55,7 @@ def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore assert isinstance(pe, Tensor) state_dict: dict[str, Tensor] = { - "Residual.Chain.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)), + "Residual.PointTypeEmbedding.weight": nn.Parameter(data=torch.cat(tensors=point_embeddings, dim=0)), "CoordinateEncoder.Linear.weight": nn.Parameter(data=pe.T.contiguous()), } @@ -80,10 +80,10 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]: mapping = converter.map_state_dicts(source_args=(x,)) assert mapping - mapping["PositionalEncoder.Chain.Parameter.parameter"] = "pos_embed" + 
mapping["PositionalEncoder.Parameter.parameter"] = "pos_embed" target_state_dict = refiners_sam_vit_h.state_dict() - del target_state_dict["PositionalEncoder.Chain.Parameter.parameter"] + del target_state_dict["PositionalEncoder.Parameter.parameter"] source_state_dict = vit.state_dict() pos_embed = source_state_dict["pos_embed"] @@ -91,8 +91,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]: target_rel_keys = [ ( - f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.horizontal_embedding", - f"Transformer.TransformerLayer_{i}.Residual_1.Chain.FusedSelfAttention.RelativePositionAttention.vertical_embedding", + f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.horizontal_embedding", + f"Transformer.TransformerLayer_{i}.Residual_1.FusedSelfAttention.RelativePositionAttention.vertical_embedding", ) for i in range(1, 33) ] @@ -112,11 +112,11 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]: source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping ) - converted_source["PositionalEncoder.Chain.Parameter.parameter"] = pos_embed # type: ignore + converted_source["PositionalEncoder.Parameter.parameter"] = pos_embed # type: ignore converted_source.update(rel_items) refiners_sam_vit_h.load_state_dict(state_dict=converted_source) - assert converter.compare_models((x,), threshold=1e-3) + assert converter.compare_models((x,), threshold=1e-2) return converted_source diff --git a/scripts/prepare-test-weights.sh b/scripts/prepare-test-weights.sh index c52986b..c4335ec 100755 --- a/scripts/prepare-test-weights.sh +++ b/scripts/prepare-test-weights.sh @@ -176,6 +176,12 @@ download_t2i_adapter () { curl -LO https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/raw/main/config.json curl -LO https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/resolve/main/diffusion_pytorch_model.bin popd + + mkdir -p tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0 + pushd tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0 + curl -LO https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/raw/main/config.json + curl -LO https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors + popd } download_sam () { @@ -191,18 +197,18 @@ convert_sd15 () { --from "tests/weights/runwayml/stable-diffusion-v1-5" \ --to "tests/weights/CLIPTextEncoderL.safetensors" \ --half - check_hash "tests/weights/CLIPTextEncoderL.safetensors" bef71657 + check_hash "tests/weights/CLIPTextEncoderL.safetensors" 6c9cbc59 python scripts/conversion/convert_diffusers_autoencoder_kl.py \ --from "tests/weights/runwayml/stable-diffusion-v1-5" \ --to "tests/weights/lda.safetensors" - check_hash "tests/weights/lda.safetensors" 28f38b35 + check_hash "tests/weights/lda.safetensors" 329e369c python scripts/conversion/convert_diffusers_unet.py \ --from "tests/weights/runwayml/stable-diffusion-v1-5" \ --to "tests/weights/unet.safetensors" \ --half - check_hash "tests/weights/unet.safetensors" d283a9a5 + check_hash "tests/weights/unet.safetensors" f81ac65a mkdir tests/weights/inpainting @@ -210,7 +216,7 @@ convert_sd15 () { --from "tests/weights/runwayml/stable-diffusion-inpainting" \ --to "tests/weights/inpainting/unet.safetensors" \ --half - check_hash "tests/weights/inpainting/unet.safetensors" 78069e20 + check_hash "tests/weights/inpainting/unet.safetensors" c07a8c61 } convert_sdxl () { @@ -219,19 +225,19 @@ convert_sdxl () { --to 
"tests/weights/DoubleCLIPTextEncoder.safetensors" \ --subfolder2 text_encoder_2 \ --half - check_hash "tests/weights/DoubleCLIPTextEncoder.safetensors" a68fd375 + check_hash "tests/weights/DoubleCLIPTextEncoder.safetensors" 7f99c30b python scripts/conversion/convert_diffusers_autoencoder_kl.py \ --from "tests/weights/stabilityai/stable-diffusion-xl-base-1.0" \ --to "tests/weights/sdxl-lda.safetensors" \ --half - check_hash "tests/weights/sdxl-lda.safetensors" b00aaf87 + check_hash "tests/weights/sdxl-lda.safetensors" 7464e9dc python scripts/conversion/convert_diffusers_unet.py \ --from "tests/weights/stabilityai/stable-diffusion-xl-base-1.0" \ --to "tests/weights/sdxl-unet.safetensors" \ --half - check_hash "tests/weights/sdxl-unet.safetensors" 861b57fd + check_hash "tests/weights/sdxl-unet.safetensors" 2e5c4911 } convert_vae_ft_mse () { @@ -239,7 +245,7 @@ convert_vae_ft_mse () { --from "tests/weights/stabilityai/sd-vae-ft-mse" \ --to "tests/weights/lda_ft_mse.safetensors" \ --half - check_hash "tests/weights/lda_ft_mse.safetensors" 6cfb7776 + check_hash "tests/weights/lda_ft_mse.safetensors" 4d0bae7e } convert_lora () { @@ -259,7 +265,7 @@ convert_preprocessors () { --from "tests/weights/carolineec/informativedrawings/model2.pth" \ --to "tests/weights/informative-drawings.safetensors" rm -f src/model.py - check_hash "tests/weights/informative-drawings.safetensors" 0294ac8a + check_hash "tests/weights/informative-drawings.safetensors" 93dca207 } convert_controlnet () { @@ -268,27 +274,27 @@ convert_controlnet () { python scripts/conversion/convert_diffusers_controlnet.py \ --from "tests/weights/lllyasviel/control_v11p_sd15_canny" \ --to "tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors" - check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors" be9ffe47 + check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors" 9a1a48cf python scripts/conversion/convert_diffusers_controlnet.py \ --from "tests/weights/lllyasviel/control_v11f1p_sd15_depth" \ --to "tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors" - check_hash "tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors" bbeaa1ba + check_hash "tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors" bbe7e5a6 python scripts/conversion/convert_diffusers_controlnet.py \ --from "tests/weights/lllyasviel/control_v11p_sd15_normalbae" \ --to "tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors" - check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors" 24520c5b + check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors" 9fa88ed5 python scripts/conversion/convert_diffusers_controlnet.py \ --from "tests/weights/lllyasviel/control_v11p_sd15_lineart" \ --to "tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors" - check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors" 5bc4de82 + check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors" c29e8c03 python scripts/conversion/convert_diffusers_controlnet.py \ --from "tests/weights/mfidabel/controlnet-segment-anything" \ --to "tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors" - check_hash "tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors" ba7059fc + check_hash "tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors" d536eebb } convert_unclip 
() { @@ -296,7 +302,7 @@ convert_unclip () { --from "tests/weights/stabilityai/stable-diffusion-2-1-unclip" \ --to "tests/weights/CLIPImageEncoderH.safetensors" \ --half - check_hash "tests/weights/CLIPImageEncoderH.safetensors" 654842e4 + check_hash "tests/weights/CLIPImageEncoderH.safetensors" 82918ff4 } convert_ip_adapter () { @@ -315,13 +321,13 @@ convert_ip_adapter () { --from "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin" \ --to "tests/weights/ip-adapter-plus_sd15.safetensors" \ --half - check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 9cea790f + check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 346a31d1 python scripts/conversion/convert_diffusers_ip_adapter.py \ --from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" \ --to "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" \ --half - check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" a090ab44 + check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" d195feb3 } convert_t2i_adapter () { @@ -330,14 +336,20 @@ convert_t2i_adapter () { --from "tests/weights/TencentARC/t2iadapter_depth_sd15v2" \ --to "tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors" \ --half - check_hash "tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors" 809a355f + check_hash "tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors" bb2b3115 + + python scripts/conversion/convert_diffusers_t2i_adapter.py \ + --from "tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0" \ + --to "tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors" \ + --half + check_hash "tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors" f07249a6 } convert_sam () { python scripts/conversion/convert_segment_anything.py \ --from "tests/weights/sam_vit_h_4b8939.pth" \ --to "tests/weights/segment-anything-h.safetensors" - check_hash "tests/weights/segment-anything-h.safetensors" e11e1ec5 + check_hash "tests/weights/segment-anything-h.safetensors" 321d6f23 } download_all () { diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index c6a28dd..bf99be1 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -6,7 +6,6 @@ import traceback from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload import torch from torch import Tensor, cat, device as Device, dtype as DType -from refiners.fluxion.layers.basics import Identity from refiners.fluxion.layers.module import Module, ContextModule, ModuleTree, WeightedModule from refiners.fluxion.context import Contexts, ContextProvider from refiners.fluxion.utils import summarize_tensor @@ -530,9 +529,12 @@ class Sum(Chain): return self.__class__ == Sum -class Residual(Sum): - def __init__(self, *modules: Module) -> None: - super().__init__(Identity(), Chain(*modules)) +class Residual(Chain): + _tag = "RES" + + def forward(self, *inputs: Any) -> Any: + assert len(inputs) == 1, "Residual connection can only be used with a single input." 
+ return super().forward(*inputs) + inputs[0] class Breakpoint(ContextModule): diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py index cdd2018..3d0bdd3 100644 --- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py @@ -10,6 +10,7 @@ from refiners.fluxion.layers import ( Sum, SelfAttention2d, Slicing, + Residual, ) from refiners.fluxion.utils import image_to_tensor, tensor_to_image from torch import Tensor, device as Device, dtype as DType @@ -82,12 +83,9 @@ class Encoder(Chain): channels: int = layer[-1].out_channels # type: ignore layer.append(Downsample(channels=channels, scale_factor=2, device=device, dtype=dtype)) - attention_layer = Sum( - Identity(), - Chain( - GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype), - SelfAttention2d(channels=resnet_sizes[-1], device=device, dtype=dtype), - ), + attention_layer = Residual( + GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype), + SelfAttention2d(channels=resnet_sizes[-1], device=device, dtype=dtype), ) resnet_layers[-1].insert_after_type(Resnet, attention_layer) super().__init__( @@ -152,12 +150,9 @@ class Decoder(Chain): ) for i in range(len(resnet_sizes)) ] - attention_layer = Sum( - Identity(), - Chain( - GroupNorm(channels=resnet_sizes[0], num_groups=32, eps=1e-6, device=device, dtype=dtype), - SelfAttention2d(channels=resnet_sizes[0], device=device, dtype=dtype), - ), + attention_layer = Residual( + GroupNorm(channels=resnet_sizes[0], num_groups=32, eps=1e-6, device=device, dtype=dtype), + SelfAttention2d(channels=resnet_sizes[0], device=device, dtype=dtype), ) resnet_layers[0].insert(1, attention_layer) for _, layer in zip(range(3), resnet_layers[1:]): diff --git a/src/refiners/foundationals/latent_diffusion/cross_attention.py b/src/refiners/foundationals/latent_diffusion/cross_attention.py index ca16bdb..18fd7a6 100644 --- a/src/refiners/foundationals/latent_diffusion/cross_attention.py +++ b/src/refiners/foundationals/latent_diffusion/cross_attention.py @@ -10,7 +10,6 @@ from refiners.fluxion.layers import ( Parallel, LayerNorm, Attention, - Sum, UseContext, Linear, GLU, @@ -19,6 +18,7 @@ from refiners.fluxion.layers import ( Conv2d, SelfAttention, SetContext, + Residual, ) @@ -41,43 +41,34 @@ class CrossAttentionBlock(Chain): self.use_bias = use_bias super().__init__( - Sum( - Identity(), - Chain( - LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), - SelfAttention( - embedding_dim=embedding_dim, num_heads=num_heads, use_bias=use_bias, device=device, dtype=dtype - ), + Residual( + LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + SelfAttention( + embedding_dim=embedding_dim, num_heads=num_heads, use_bias=use_bias, device=device, dtype=dtype ), ), - Sum( - Identity(), - Chain( - LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), - Parallel( - Identity(), - UseContext(context=self.context, key=context_key), - UseContext(context=self.context, key=context_key), - ), - Attention( - embedding_dim=embedding_dim, - num_heads=num_heads, - key_embedding_dim=context_embedding_dim, - value_embedding_dim=context_embedding_dim, - use_bias=use_bias, - device=device, - dtype=dtype, - ), + Residual( + LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + Parallel( + Identity(), + UseContext(context=self.context, key=context_key), + 
UseContext(context=self.context, key=context_key), + ), + Attention( + embedding_dim=embedding_dim, + num_heads=num_heads, + key_embedding_dim=context_embedding_dim, + value_embedding_dim=context_embedding_dim, + use_bias=use_bias, + device=device, + dtype=dtype, ), ), - Sum( - Identity(), - Chain( - LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), - Linear(in_features=embedding_dim, out_features=2 * 4 * embedding_dim, device=device, dtype=dtype), - GLU(GeLU()), - Linear(in_features=4 * embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), - ), + Residual( + LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype), + Linear(in_features=embedding_dim, out_features=2 * 4 * embedding_dim, device=device, dtype=dtype), + GLU(GeLU()), + Linear(in_features=4 * embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), ), ) @@ -98,7 +89,7 @@ class StatefulFlatten(Chain): ) -class CrossAttentionBlock2d(Sum): +class CrossAttentionBlock2d(Residual): def __init__( self, channels: int, @@ -164,23 +155,20 @@ class CrossAttentionBlock2d(Sum): ) super().__init__( - Identity(), + in_block, Chain( - in_block, - Chain( - CrossAttentionBlock( - embedding_dim=channels, - context_embedding_dim=context_embedding_dim, - context_key=context_key, - num_heads=num_attention_heads, - use_bias=use_bias, - device=device, - dtype=dtype, - ) - for _ in range(num_attention_layers) - ), - out_block, + CrossAttentionBlock( + embedding_dim=channels, + context_embedding_dim=context_embedding_dim, + context_key=context_key, + num_heads=num_attention_heads, + use_bias=use_bias, + device=device, + dtype=dtype, + ) + for _ in range(num_attention_layers) ), + out_block, ) def init_context(self) -> Contexts: diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 24bd58b..99678dd 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -1,5 +1,5 @@ from refiners.fluxion.context import Contexts -from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity, Slicing +from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Slicing, Residual from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( SD1UNet, DownBlocks, @@ -92,12 +92,9 @@ class Controlnet(Passthrough): # We run the condition encoder at each step. Caching the result # is not worth it as subsequent runs take virtually no time (FG-374). 
self.DownBlocks[0].append( - Sum( - Identity(), - Chain( - UseContext("controlnet", f"condition_{name}"), - ConditionEncoder(device=device, dtype=dtype), - ), + Residual( + UseContext("controlnet", f"condition_{name}"), + ConditionEncoder(device=device, dtype=dtype), ), ) for residual_block in self.layers(ResidualBlock): diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index b2b40bc..e9408de 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -131,7 +131,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru y_1 = facebook_sam_h.image_encoder(image_tensor) y_2 = sam_h.image_encoder(image_tensor) - assert torch.equal(input=y_1, other=y_2) + assert torch.allclose(input=y_1, other=y_2, atol=1e-4) @torch.no_grad()
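
Note on the core change (illustrative sketch, not part of the patch): Residual is now a plain Chain whose forward adds the block's input to the output of its sub-modules, instead of wrapping them in Sum(Identity(), Chain(...)). The sketch below assumes a refiners checkout that includes this patch; the Linear layer and its size are arbitrary, and the printed key names are indicative. It shows why the conversion scripts above drop one ".Chain" level from their target keys.

    import torch
    import refiners.fluxion.layers as fl

    # Pre-patch pattern: a residual connection emulated with Sum + Identity.
    old_style = fl.Sum(
        fl.Identity(),
        fl.Chain(fl.Linear(in_features=4, out_features=4)),
    )

    # Post-patch pattern: Residual is a Chain whose forward returns f(x) + x.
    new_style = fl.Residual(fl.Linear(in_features=4, out_features=4))

    # The anonymous Chain level disappears from the parameter paths, e.g.
    #   old: ["Chain.Linear.weight", "Chain.Linear.bias"]
    #   new: ["Linear.weight", "Linear.bias"]
    print(list(old_style.state_dict().keys()))
    print(list(new_style.state_dict().keys()))

    # Both forms compute the sub-chain output plus the input.
    x = torch.randn(2, 4)
    y = new_style(x)  # Linear(x) + x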
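
Note on the relaxed comparisons (illustrative sketch, not part of the patch): the commit message reports a small float32 difference in SAM's neck layer after the refactor, so convert_segment_anything.py raises the converter threshold from 1e-3 to 1e-2 and test_sam.py switches from torch.equal to torch.allclose with atol=1e-4. The sketch below uses plain PyTorch with a made-up drift magnitude to show what that switch tolerates.

    import torch

    reference = torch.randn(1, 256, 64, 64)                       # stand-in for the source image-encoder output
    refactored = reference + 1e-6 * torch.randn_like(reference)   # simulated tiny numerical drift

    print(torch.equal(reference, refactored))                     # False: bitwise equality is too strict
    print(torch.allclose(reference, refactored, atol=1e-4))       # True: drift is well below the new tolerance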