From 471ef91d1ca708f99898b9e8cf6f0924267e12a6 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 5 Feb 2024 17:10:05 +0100 Subject: [PATCH] make `__getattr__` on Module return object, not Any PyTorch chose to make it Any because they expect its users' code to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321 It is not the case for us, in Refiners having untyped code goes contrary to one of our core principles. Note that there is currently an open PR in PyTorch to return `Module | Tensor`, but in practice this is not always correct either: https://github.com/pytorch/pytorch/pull/115074 I also moved Residuals-related code from SD1 to latent_diffusion because SDXL should not depend on SD1. --- .../convert_diffusers_controlnet.py | 2 +- scripts/conversion/convert_diffusers_unet.py | 2 +- .../conversion/convert_segment_anything.py | 5 +- src/refiners/fluxion/adapters/adapter.py | 12 ++- src/refiners/fluxion/layers/chain.py | 46 +++++++++- src/refiners/fluxion/layers/module.py | 22 ++++- .../latent_diffusion/auto_encoder.py | 2 +- .../foundationals/latent_diffusion/freeu.py | 4 +- .../latent_diffusion/image_prompt.py | 4 +- .../foundationals/latent_diffusion/model.py | 6 +- .../stable_diffusion_1/controlnet.py | 25 +++--- .../stable_diffusion_1/t2i_adapter.py | 4 +- .../stable_diffusion_1/unet.py | 85 ++----------------- .../stable_diffusion_xl/t2i_adapter.py | 13 +-- .../stable_diffusion_xl/text_encoder.py | 4 +- .../stable_diffusion_xl/unet.py | 6 +- .../foundationals/latent_diffusion/unet.py | 79 +++++++++++++++++ .../segment_anything/prompt_encoder.py | 3 +- tests/adapters/test_adapter.py | 14 +-- tests/adapters/test_lora.py | 3 + tests/adapters/test_range_adapter.py | 5 +- tests/fluxion/layers/test_chain.py | 65 ++++++++++---- tests/fluxion/test_module.py | 9 +- tests/foundationals/clip/test_concepts.py | 11 +-- .../latent_diffusion/test_freeu.py | 3 +- .../segment_anything/test_sam.py | 22 +++-- 26 files changed, 288 insertions(+), 168 deletions(-) create mode 100644 src/refiners/foundationals/latent_diffusion/unet.py diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index e603bf8..d707298 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -29,7 +29,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]: ) unet = SD1UNet(in_channels=4) adapter = SD1ControlnetAdapter(unet, name="mycn").inject() - controlnet = unet.Controlnet + controlnet = adapter.controlnet condition = torch.randn(1, 3, 512, 512) adapter.set_controlnet_condition(condition=condition) diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py index c762dee..f194296 100644 --- a/scripts/conversion/convert_diffusers_unet.py +++ b/scripts/conversion/convert_diffusers_unet.py @@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter: target.set_timestep(timestep=timestep) target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) added_cond_kwargs = {} - if source_has_time_ids: + if isinstance(target, SDXLUNet): added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index c41c2b8..7ce70ce 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -11,7 +11,7 @@ from torch import Tensor import refiners.fluxion.layers as fl from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors -from refiners.foundationals.segment_anything.image_encoder import SAMViTH +from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder @@ -136,7 +136,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]: source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping ) - embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight) + positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder) + embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight) converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore converted_source.update(rel_items) diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py index 6708583..b1cd57b 100644 --- a/src/refiners/fluxion/adapters/adapter.py +++ b/src/refiners/fluxion/adapters/adapter.py @@ -43,14 +43,12 @@ class Adapter(Generic[T]): ), "Call the Chain constructor in the setup_adapter context." self._target = [target] - if not isinstance(self.target, fl.ContextModule): + if isinstance(target, fl.ContextModule): + assert isinstance(target, fl.ContextModule) + with target.no_parent_refresh(): + yield + else: yield - return - - _old_can_refresh_parent = target._can_refresh_parent - target._can_refresh_parent = False - yield - target._can_refresh_parent = _old_can_refresh_parent def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter: """Inject the adapter. diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 34f6325..acd2ddc 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -3,7 +3,7 @@ import re import sys import traceback from collections import defaultdict -from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload +from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload import torch from torch import Tensor, cat, device as Device, dtype as DType @@ -362,6 +362,48 @@ class Chain(ContextModule): recurse=recurse, ) + def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T: + """Access a layer of the Chain given its type. + + Example: + ```py + # same as my_chain["Linear_2"], asserts it is a Linear + my_chain.layer("Linear_2", fl.Linear) + + + # same as my_chain[3], asserts it is a Linear + my_chain.layer(3, fl.Linear) + + # probably won't work + my_chain.layer("Conv2d", fl.Linear) + + + # same as my_chain["foo"][42]["bar"], + # assuming bar is a MyType and all parents are Chains + my_chain.layer(("foo", 42, "bar"), fl.MyType) + ``` + + Args: + key: The key or path of the layer. + layer_type: The type of the layer. + + Yields: + The layer. + + Raises: + AssertionError: If the layer doesn't exist or the type is invalid. + """ + if isinstance(key, (str, int)): + r = self[key] + assert isinstance(r, layer_type), f"layer {key} is {type(r)}, not {layer_type}" + return r + if len(key) == 0: + assert isinstance(self, layer_type), f"layer is {type(self)}, not {layer_type}" + return self + if len(key) == 1: + return self.layer(key[0], layer_type) + return self.layer(key[0], Chain).layer(key[1:], layer_type) + def layers( self, layer_type: type[T], @@ -576,6 +618,7 @@ class Chain(ContextModule): require extra GPU memory since the weights are in the leaves and hence not copied. """ if hasattr(self, "_pre_structural_copy"): + assert callable(self._pre_structural_copy) self._pre_structural_copy() modules = [structural_copy(m) for m in self] @@ -586,6 +629,7 @@ class Chain(ContextModule): clone.append(module=module) if hasattr(clone, "_post_structural_copy"): + assert callable(clone._post_structural_copy) clone._post_structural_copy(self) return clone diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py index 3654a1c..3899f9a 100644 --- a/src/refiners/fluxion/layers/module.py +++ b/src/refiners/fluxion/layers/module.py @@ -1,11 +1,12 @@ +import contextlib import sys from collections import defaultdict from inspect import Parameter, signature from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast +from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Iterator, Sequence, TypedDict, TypeVar, cast -from torch import device as Device, dtype as DType +from torch import Tensor, device as Device, dtype as DType from torch.nn.modules.module import Module as TorchModule from refiners.fluxion.context import Context, ContextProvider @@ -29,7 +30,13 @@ class Module(TorchModule): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, *kwargs) # type: ignore[reportUnknownMemberType] - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name: str) -> object: + # Note: PyTorch returns `Any` as of 2.2 and is considering + # going back to `Tensor | Module`, but the truth is it is + # impossible to type `__getattr__` correctly. + # Because PyTorch assumes its users write highly dynamic code, + # it returns Python's top type `Any`. In Refiners, static type + # checking is a core design value, hence we return `object` instead. return super().__getattr__(name=name) def __setattr__(self, name: str, value: Any) -> None: @@ -220,10 +227,19 @@ class ContextModule(Module): return super().get_path(parent=parent or self.parent, top=top) + @contextlib.contextmanager + def no_parent_refresh(self) -> Iterator[None]: + _old_can_refresh_parent = self._can_refresh_parent + self._can_refresh_parent = False + yield + self._can_refresh_parent = _old_can_refresh_parent + class WeightedModule(Module): """A module with a weight (Tensor) attribute.""" + weight: Tensor + @property def device(self) -> Device: """Return the device of the module's weight.""" diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py index a53e517..f17c783 100644 --- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py @@ -157,7 +157,7 @@ class Decoder(Chain): ) resnet_layers[0].insert(1, attention_layer) for _, layer in zip(range(3), resnet_layers[1:]): - channels: int = layer[-1].out_channels + channels: int = layer.layer(-1, Resnet).out_channels layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype)) super().__init__( Conv2d( diff --git a/src/refiners/foundationals/latent_diffusion/freeu.py b/src/refiners/foundationals/latent_diffusion/freeu.py index 3e2580f..f05b874 100644 --- a/src/refiners/foundationals/latent_diffusion/freeu.py +++ b/src/refiners/foundationals/latent_diffusion/freeu.py @@ -73,7 +73,7 @@ class FreeUResidualConcatenator(fl.Concatenate): class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]): def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None: assert len(backbone_scales) == len(skip_scales) - assert len(backbone_scales) <= len(target.UpBlocks) + assert len(backbone_scales) <= len(target.layer("UpBlocks", fl.Chain)) self.backbone_scales = backbone_scales self.skip_scales = skip_scales with self.setup_adapter(target): @@ -88,7 +88,7 @@ class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]): def eject(self) -> None: for n in range(len(self.backbone_scales)): - block = self.target.UpBlocks[n] + block = self.target.layer(("UpBlocks", n), fl.Chain) concat = block.ensure_find(FreeUResidualConcatenator) block.replace(concat, ResidualConcatenator(-n - 2)) super().eject() diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 1c35eb9..e90337c 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -307,11 +307,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): @property def image_key_projection(self) -> fl.Linear: - return self.image_cross_attention.Distribute[1].Linear + return self.image_cross_attention.layer(("Distribute", 1, "Linear"), fl.Linear) @property def image_value_projection(self) -> fl.Linear: - return self.image_cross_attention.Distribute[2].Linear + return self.image_cross_attention.layer(("Distribute", 2, "Linear"), fl.Linear) @property def scale(self) -> float: diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index b4a2f82..697debe 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -9,17 +9,15 @@ import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.solvers import Solver -T = TypeVar("T", bound="fl.Module") - TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel") class LatentDiffusionModel(fl.Module, ABC): def __init__( self, - unet: fl.Module, + unet: fl.Chain, lda: LatentDiffusionAutoencoder, - clip_text_encoder: fl.Module, + clip_text_encoder: fl.Chain, solver: Solver, device: Device | str = "cpu", dtype: DType = torch.float32, diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 1b8b9ff..8729bd4 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -1,5 +1,3 @@ -from typing import Iterable, cast - from torch import Tensor, device as Device, dtype as DType from refiners.fluxion.adapters.adapter import Adapter @@ -93,28 +91,31 @@ class Controlnet(Passthrough): # We run the condition encoder at each step. Caching the result # is not worth it as subsequent runs take virtually no time (FG-374). - self.DownBlocks[0].append( + + self.layer(("DownBlocks", 0), Chain).append( Residual( UseContext("controlnet", f"condition_{name}"), ConditionEncoder(device=device, dtype=dtype), ), ) for residual_block in self.layers(ResidualBlock): - chain = residual_block.Chain + chain = residual_block.layer("Chain", Chain) RangeAdapter2d( - target=chain.Conv2d_1, + target=chain.layer("Conv2d_1", Conv2d), channels=residual_block.out_channels, embedding_dim=1280, context_key=f"timestep_embedding_{name}", device=device, dtype=dtype, ).inject(chain) - for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)): - assert hasattr(block[0], "out_channels"), ( + for n, block in enumerate(self.layer("DownBlocks", DownBlocks)): + assert isinstance(block, Chain) + b0 = block[0] + assert hasattr(b0, "out_channels"), ( "The first block of every subchain in DownBlocks is expected to respond to `out_channels`," - f" {block[0]} does not." + f" {b0} does not." ) - out_channels: int = block[0].out_channels + assert isinstance(out_channels := b0.out_channels, int) block.append( Passthrough( Conv2d( @@ -123,7 +124,7 @@ class Controlnet(Passthrough): Lambda(self._store_nth_residual(n)), ) ) - self.MiddleBlock.append( + self.layer("MiddleBlock", MiddleBlock).append( Passthrough( Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype), Lambda(self._store_nth_residual(12)), @@ -166,6 +167,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]): self.target.remove(self._controlnet[0]) super().eject() + @property + def controlnet(self) -> Controlnet: + return self._controlnet[0] + def init_context(self) -> Contexts: return {"controlnet": {f"condition_{self.name}": None}} diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py index c70b7ce..b6b2a97 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py @@ -25,7 +25,7 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]): def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter": for n, feat in zip(self.residual_indices, self._features, strict=True): - block = self.target.DownBlocks[n] + block = self.target.layer(("DownBlocks", n), fl.Chain) for t2i_layer in block.layers(layer_type=T2IFeatures): assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected" block.insert_before_type(ResidualAccumulator, feat) @@ -33,5 +33,5 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]): def eject(self: "SD1T2IAdapter") -> None: for n, feat in zip(self.residual_indices, self._features, strict=True): - self.target.DownBlocks[n].remove(feat) + self.target.layer(("DownBlocks", n), fl.Chain).remove(feat) super().eject() diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index 8f102a4..25e7502 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -6,6 +6,11 @@ import refiners.fluxion.layers as fl from refiners.fluxion.context import Contexts from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder +from refiners.foundationals.latent_diffusion.unet import ( + ResidualAccumulator, + ResidualBlock, + ResidualConcatenator, +) class TimestepEncoder(fl.Passthrough): @@ -22,54 +27,6 @@ class TimestepEncoder(fl.Passthrough): ) -class ResidualBlock(fl.Sum): - def __init__( - self, - in_channels: int, - out_channels: int, - num_groups: int = 32, - eps: float = 1e-5, - device: Device | str | None = None, - dtype: DType | None = None, - ) -> None: - if in_channels % num_groups != 0 or out_channels % num_groups != 0: - raise ValueError("Number of input and output channels must be divisible by num_groups.") - self.in_channels = in_channels - self.out_channels = out_channels - self.num_groups = num_groups - self.eps = eps - shortcut = ( - fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype) - if in_channels != out_channels - else fl.Identity() - ) - super().__init__( - fl.Chain( - fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype), - fl.SiLU(), - fl.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - device=device, - dtype=dtype, - ), - fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype), - fl.SiLU(), - fl.Conv2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - device=device, - dtype=dtype, - ), - ), - shortcut, - ) - - class CLIPLCrossAttention(CrossAttentionBlock2d): def __init__( self, @@ -205,34 +162,6 @@ class MiddleBlock(fl.Chain): ) -class ResidualAccumulator(fl.Passthrough): - def __init__(self, n: int) -> None: - self.n = n - - super().__init__( - fl.Residual( - fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n]) - ), - fl.SetContext(context="unet", key="residuals", callback=self.update), - ) - - def update(self, residuals: list[Tensor | float], x: Tensor) -> None: - residuals[self.n] = x - - -class ResidualConcatenator(fl.Chain): - def __init__(self, n: int) -> None: - self.n = n - - super().__init__( - fl.Concatenate( - fl.Identity(), - fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]), - dim=1, - ), - ) - - class SD1UNet(fl.Chain): """Stable Diffusion 1.5 U-Net. @@ -275,9 +204,9 @@ class SD1UNet(fl.Chain): ), ) for residual_block in self.layers(ResidualBlock): - chain = residual_block.Chain + chain = residual_block.layer("Chain", fl.Chain) RangeAdapter2d( - target=chain.Conv2d_1, + target=chain.layer("Conv2d_1", fl.Conv2d), channels=residual_block.out_channels, embedding_dim=1280, context_key="timestep_embedding", diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py index 3e6d8a3..7e709f1 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py @@ -2,7 +2,7 @@ from torch import Tensor import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator -from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet +from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, SDXLUNet from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures @@ -31,18 +31,19 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]): # Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below) for n, feat in zip(self.residual_indices, self._features, strict=False): - block = self.target.DownBlocks[n] + block = self.target.layer(("DownBlocks", n), fl.Chain) sanity_check_t2i(block) block.insert_before_type(ResidualAccumulator, feat) # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append - sanity_check_t2i(self.target.MiddleBlock) - self.target.MiddleBlock.append(self._features[-1]) + mid_block = self.target.layer("MiddleBlock", MiddleBlock) + sanity_check_t2i(mid_block) + mid_block.append(self._features[-1]) return super().inject(parent) def eject(self: "SDXLT2IAdapter") -> None: # See `inject` re: `strict=False` for n, feat in zip(self.residual_indices, self._features, strict=False): - self.target.DownBlocks[n].remove(feat) - self.target.MiddleBlock.remove(self._features[-1]) + self.target.layer(("DownBlocks", n), fl.Chain).remove(feat) + self.target.layer("MiddleBlock", MiddleBlock).remove(self._features[-1]) super().eject() diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py index 72a6245..cd45e5a 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py @@ -72,7 +72,9 @@ class DoubleTextEncoder(fl.Chain): fl.Parallel(text_encoder_l[:-2], text_encoder_g), fl.Lambda(func=self.concatenate_embeddings), ) - TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel) + TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject( + parent=self.layer("Parallel", fl.Parallel) + ) def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]: return super().__call__(text) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 6514442..9663fa6 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -10,7 +10,7 @@ from refiners.foundationals.latent_diffusion.range_adapter import ( RangeEncoder, compute_sinusoidal_embedding, ) -from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( +from refiners.foundationals.latent_diffusion.unet import ( ResidualAccumulator, ResidualBlock, ResidualConcatenator, @@ -267,9 +267,9 @@ class SDXLUNet(fl.Chain): OutputBlock(device=device, dtype=dtype), ) for residual_block in self.layers(ResidualBlock): - chain = residual_block.Chain + chain = residual_block.layer("Chain", fl.Chain) RangeAdapter2d( - target=chain.Conv2d_1, + target=chain.layer("Conv2d_1", fl.Conv2d), channels=residual_block.out_channels, embedding_dim=1280, context_key="timestep_embedding", diff --git a/src/refiners/foundationals/latent_diffusion/unet.py b/src/refiners/foundationals/latent_diffusion/unet.py new file mode 100644 index 0000000..fc7198d --- /dev/null +++ b/src/refiners/foundationals/latent_diffusion/unet.py @@ -0,0 +1,79 @@ +from torch import Tensor, device as Device, dtype as DType + +import refiners.fluxion.layers as fl + + +class ResidualBlock(fl.Sum): + def __init__( + self, + in_channels: int, + out_channels: int, + num_groups: int = 32, + eps: float = 1e-5, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + if in_channels % num_groups != 0 or out_channels % num_groups != 0: + raise ValueError("Number of input and output channels must be divisible by num_groups.") + self.in_channels = in_channels + self.out_channels = out_channels + self.num_groups = num_groups + self.eps = eps + shortcut = ( + fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype) + if in_channels != out_channels + else fl.Identity() + ) + super().__init__( + fl.Chain( + fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype), + fl.SiLU(), + fl.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + device=device, + dtype=dtype, + ), + fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype), + fl.SiLU(), + fl.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + device=device, + dtype=dtype, + ), + ), + shortcut, + ) + + +class ResidualAccumulator(fl.Passthrough): + def __init__(self, n: int) -> None: + self.n = n + + super().__init__( + fl.Residual( + fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n]) + ), + fl.SetContext(context="unet", key="residuals", callback=self.update), + ) + + def update(self, residuals: list[Tensor | float], x: Tensor) -> None: + residuals[self.n] = x + + +class ResidualConcatenator(fl.Chain): + def __init__(self, n: int) -> None: + self.n = n + + super().__init__( + fl.Concatenate( + fl.Identity(), + fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]), + dim=1, + ), + ) diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py index edcd116..f062ecb 100644 --- a/src/refiners/foundationals/segment_anything/prompt_encoder.py +++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py @@ -188,6 +188,7 @@ class MaskEncoder(fl.Chain): def get_no_mask_dense_embedding( self, image_embedding_size: tuple[int, int], batch_size: int = 1 ) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]: - return self.no_mask_embedding.reshape(1, -1, 1, 1).expand( + no_mask_embedding = cast(Tensor, self.no_mask_embedding) + return no_mask_embedding.reshape(1, -1, 1, 1).expand( batch_size, -1, image_embedding_size[0], image_embedding_size[1] ) diff --git a/tests/adapters/test_adapter.py b/tests/adapters/test_adapter.py index 8049f97..38b1bd0 100644 --- a/tests/adapters/test_adapter.py +++ b/tests/adapters/test_adapter.py @@ -22,8 +22,8 @@ def chain() -> Chain: def test_weighted_module_adapter_insertion(chain: Chain): - parent = chain.Chain - adaptee = parent.Linear + parent = chain.layer("Chain", Chain) + adaptee = parent.layer("Linear", Linear) adapter = DummyLinearAdapter(adaptee).inject(parent) @@ -39,7 +39,7 @@ def test_weighted_module_adapter_insertion(chain: Chain): def test_chain_adapter_insertion(chain: Chain): parent = chain - adaptee = parent.Chain + adaptee = parent.layer("Chain", Chain) adapter = DummyChainAdapter(adaptee) assert adaptee.parent == parent @@ -58,20 +58,20 @@ def test_chain_adapter_insertion(chain: Chain): def test_weighted_module_adapter_structural_copy(chain: Chain): - parent = chain.Chain - adaptee = parent.Linear + parent = chain.layer("Chain", Chain) + adaptee = parent.layer("Linear", Linear) DummyLinearAdapter(adaptee).inject(parent) clone = chain.structural_copy() - cloned_adapter = clone.Chain.DummyLinearAdapter + cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter) assert cloned_adapter.parent == clone.Chain assert cloned_adapter.target == adaptee def test_chain_adapter_structural_copy(chain: Chain): # Chain adapters cannot be copied by default. - adapter = DummyChainAdapter(chain.Chain).inject() + adapter = DummyChainAdapter(chain.layer("Chain", Chain)).inject() with pytest.raises(RuntimeError): chain.structural_copy() diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py index d298aa1..8666982 100644 --- a/tests/adapters/test_lora.py +++ b/tests/adapters/test_lora.py @@ -27,6 +27,7 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None: assert conv_lora.scale == 1.0 assert conv_lora.in_channels == conv_lora.down.in_channels == 16 assert conv_lora.out_channels == conv_lora.up.out_channels == 8 + assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d) assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1) # padding is set so the spatial dimensions are preserved assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1) @@ -39,10 +40,12 @@ def test_scale_setter(lora: LinearLora) -> None: def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None: + assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear) new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight) x = torch.randn(1, 320) assert torch.allclose(lora(x), new_lora(x)) + assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d) new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight) x = torch.randn(1, 16, 64, 64) assert torch.allclose(conv_lora(x), new_conv_lora(x)) diff --git a/tests/adapters/test_range_adapter.py b/tests/adapters/test_range_adapter.py index 90e3bdc..f42cb99 100644 --- a/tests/adapters/test_range_adapter.py +++ b/tests/adapters/test_range_adapter.py @@ -15,8 +15,9 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG- dtype = torch.float64 chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype)) - adaptee = chain.RangeEncoder.Linear_1 - adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder) + range_encoder = chain.layer("RangeEncoder", RangeEncoder) + adaptee = range_encoder.layer("Linear_1", Linear) + adapter = DummyLinearAdapter(adaptee).inject(range_encoder) assert adapter.parent == chain.RangeEncoder diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 5e87812..03e861f 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -30,8 +30,9 @@ def test_chain_getitem_accessor() -> None: def test_chain_find_parent(): chain = fl.Chain(fl.Chain(fl.Linear(1, 1))) + subchain = chain.layer("Chain", fl.Chain) - assert chain.find_parent(chain.Chain.Linear) == chain.Chain + assert chain.find_parent(subchain.layer("Linear", fl.Linear)) == subchain assert chain.find_parent(fl.Linear(1, 1)) is None @@ -64,17 +65,20 @@ def test_chain_walk() -> None: fl.Chain(), ) - assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)] + sum_ = chain.layer("Sum", fl.Sum) + sum_chain = sum_.layer("Chain", fl.Chain) + + assert list(chain.walk()) == [(sum_, chain), (chain.Chain, chain)] assert list(chain.walk(fl.Linear)) == [ - (chain.Sum.Chain.Linear, chain.Sum.Chain), - (chain.Sum.Linear, chain.Sum), + (sum_chain.Linear, sum_chain), + (sum_.Linear, sum_), ] assert list(chain.walk(recurse=True)) == [ - (chain.Sum, chain), - (chain.Sum.Chain, chain.Sum), - (chain.Sum.Chain.Linear, chain.Sum.Chain), - (chain.Sum.Linear, chain.Sum), + (sum_, chain), + (sum_chain, sum_), + (sum_chain.Linear, sum_chain), + (sum_.Linear, sum_), (chain.Chain, chain), ] @@ -98,6 +102,29 @@ def test_chain_walk_stop_iteration() -> None: assert len(list(chain.walk(predicate))) == 1 +def test_chain_layer() -> None: + chain = fl.Chain( + fl.Sum(fl.Chain(), fl.Chain()), + ) + + sum_ = chain.layer(0, fl.Sum) + assert chain.layer("Sum", fl.Sum) == sum_ + assert chain.layer("Sum", fl.Chain) == sum_ + + chain_2 = chain.layer((0, 1), fl.Chain) + assert chain.layer((0, 1)) == chain_2 + assert chain.layer((0, "Chain_2"), fl.Chain) == chain_2 + assert chain.layer(("Sum", "Chain_2"), fl.Chain) == chain_2 + + assert chain.layer((), fl.Chain) == chain + + with pytest.raises(AssertionError): + chain.layer((0, 1), fl.Sum) + + with pytest.raises(AssertionError): + chain.layer((), fl.Sum) + + def test_chain_layers() -> None: chain = fl.Chain( fl.Chain(fl.Chain(fl.Chain())), @@ -214,11 +241,12 @@ def test_chain_replace() -> None: fl.Linear(1, 1), fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), ) + subchain = chain.layer("Chain", fl.Chain) - assert isinstance(chain.Chain[1], fl.Linear) - chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1)) + assert isinstance(subchain[1], fl.Linear) + subchain.replace(subchain[1], fl.Conv2d(1, 1, 1)) assert len(chain) == 3 - assert isinstance(chain.Chain[1], fl.Conv2d) + assert isinstance(subchain[1], fl.Conv2d) def test_chain_structural_copy() -> None: @@ -236,15 +264,18 @@ def test_chain_structural_copy() -> None: m2 = m.structural_copy() - assert m.Linear == m2.Linear - assert m.Sum.Linear_1 == m2.Sum.Linear_1 - assert m.Sum.Linear_2 == m2.Sum.Linear_2 + m_sum = m.layer("Sum", fl.Sum) + m2_sum = m2.layer("Sum", fl.Sum) - assert m.Sum != m2.Sum + assert m.Linear == m2.Linear + assert m_sum.Linear_1 == m2_sum.Linear_1 + assert m_sum.Linear_2 == m2_sum.Linear_2 + + assert m_sum != m2_sum assert m != m2 - assert m.Sum.parent == m - assert m2.Sum.parent == m2 + assert m_sum.parent == m + assert m2_sum.parent == m2 y2 = m2(x) assert y2.shape == (7, 12) diff --git a/tests/fluxion/test_module.py b/tests/fluxion/test_module.py index bf87650..fc82398 100644 --- a/tests/fluxion/test_module.py +++ b/tests/fluxion/test_module.py @@ -10,9 +10,12 @@ def test_module_get_path() -> None: fl.Sum(), ) - assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" - assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" - assert chain.Sum_1.get_path() == "Chain.Sum_1" + sum_1 = chain.layer("Sum_1", fl.Sum) + linear_2 = sum_1.layer("Linear_2", fl.Linear) + + assert linear_2.get_path(parent=sum_1) == "Chain.Sum_1.Linear_2" + assert linear_2.get_path(parent=sum_1, top=sum_1) == "Sum.Linear_2" + assert sum_1.get_path() == "Chain.Sum_1" def test_module_basic_attributes() -> None: diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index 9ab1e7e..debeee5 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -85,13 +85,14 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch. def test_tokenizer_with_special_character(): - clip_tokenizer = fl.Chain(CLIPTokenizer()) - token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer) - new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42 + clip_tokenizer_chain = fl.Chain(CLIPTokenizer()) + original_clip_tokenizer = clip_tokenizer_chain.layer("CLIPTokenizer", CLIPTokenizer) + token_extender = TokenExtender(original_clip_tokenizer) + new_token_id = max(original_clip_tokenizer.token_to_id_mapping.values()) + 42 token_extender.add_token("*", new_token_id) - token_extender.inject(clip_tokenizer) + token_extender.inject(clip_tokenizer_chain) - adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer) + adapted_clip_tokenizer = clip_tokenizer_chain.ensure_find(CLIPTokenizer) assert torch.allclose( adapted_clip_tokenizer.encode("*"), diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py index 3e4553e..487f369 100644 --- a/tests/foundationals/latent_diffusion/test_freeu.py +++ b/tests/foundationals/latent_diffusion/test_freeu.py @@ -4,6 +4,7 @@ import pytest import torch from refiners.fluxion import manual_seed +from refiners.fluxion.layers import Chain from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter @@ -33,7 +34,7 @@ def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None: def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None: - num_blocks = len(unet.UpBlocks) + num_blocks = len(unet.layer("UpBlocks", Chain)) with pytest.raises(AssertionError): SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1)) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 1a62217..e8afa41 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -16,10 +16,11 @@ from tests.foundationals.segment_anything.utils import ( ) from torch import Tensor +import refiners.fluxion.layers as fl from refiners.fluxion import manual_seed from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad -from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention +from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer @@ -104,17 +105,22 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None: manual_seed(seed=0) x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device) - attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) # type: ignore + attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) refiners_attention = FusedSelfAttention( embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device ) - refiners_attention.Linear_1.weight = attention.qkv.weight # type: ignore - refiners_attention.Linear_1.bias = attention.qkv.bias # type: ignore - refiners_attention.Linear_2.weight = attention.proj.weight # type: ignore - refiners_attention.Linear_2.bias = attention.proj.bias # type: ignore - refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w - refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h + + rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention) + linear_1 = refiners_attention.layer("Linear_1", fl.Linear) + linear_2 = refiners_attention.layer("Linear_2", fl.Linear) + + linear_1.weight = attention.qkv.weight + linear_1.bias = attention.qkv.bias + linear_2.weight = attention.proj.weight + linear_2.bias = attention.proj.bias + rpa.horizontal_embedding = attention.rel_pos_w + rpa.vertical_embedding = attention.rel_pos_h y_1 = attention(x) assert y_1.shape == x.shape