diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py
index e603bf8..d707298 100644
--- a/scripts/conversion/convert_diffusers_controlnet.py
+++ b/scripts/conversion/convert_diffusers_controlnet.py
@@ -29,7 +29,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
     )
     unet = SD1UNet(in_channels=4)
     adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
-    controlnet = unet.Controlnet
+    controlnet = adapter.controlnet
 
     condition = torch.randn(1, 3, 512, 512)
     adapter.set_controlnet_condition(condition=condition)
diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py
index c762dee..f194296 100644
--- a/scripts/conversion/convert_diffusers_unet.py
+++ b/scripts/conversion/convert_diffusers_unet.py
@@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
     target.set_timestep(timestep=timestep)
     target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     added_cond_kwargs = {}
-    if source_has_time_ids:
+    if isinstance(target, SDXLUNet):
         added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
         target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
         target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py
index c41c2b8..7ce70ce 100644
--- a/scripts/conversion/convert_segment_anything.py
+++ b/scripts/conversion/convert_segment_anything.py
@@ -11,7 +11,7 @@ from torch import Tensor
 import refiners.fluxion.layers as fl
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
-from refiners.foundationals.segment_anything.image_encoder import SAMViTH
+from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
 
@@ -136,7 +136,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
         source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
     )
 
-    embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
+    positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
+    embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
     converted_source["PositionalEncoder.Parameter.weight"] = embed  # type: ignore
 
     converted_source.update(rel_items)
diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py
index 6708583..b1cd57b 100644
--- a/src/refiners/fluxion/adapters/adapter.py
+++ b/src/refiners/fluxion/adapters/adapter.py
@@ -43,14 +43,12 @@ class Adapter(Generic[T]):
         ), "Call the Chain constructor in the setup_adapter context."
         self._target = [target]
 
-        if not isinstance(self.target, fl.ContextModule):
+        if isinstance(target, fl.ContextModule):
+            assert isinstance(target, fl.ContextModule)
+            with target.no_parent_refresh():
+                yield
+        else:
             yield
-            return
-
-        _old_can_refresh_parent = target._can_refresh_parent
-        target._can_refresh_parent = False
-        yield
-        target._can_refresh_parent = _old_can_refresh_parent
 
     def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
         """Inject the adapter.
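The `setup_adapter` rewrite above delegates the parent-refresh bookkeeping to the new `ContextModule.no_parent_refresh` context manager (added in `module.py` further down). A minimal sketch of how an adapter relies on it, mirroring the `DummyLinearAdapter` used in `tests/adapters/test_adapter.py` below:

```py
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter


class DummyLinearAdapter(fl.Chain, Adapter[fl.Linear]):
    def __init__(self, target: fl.Linear) -> None:
        # setup_adapter freezes `target._can_refresh_parent` so that wrapping
        # `target` in this Chain does not steal it from its current parent.
        with self.setup_adapter(target):
            super().__init__(target)


parent = fl.Chain(fl.Linear(4, 4))
adaptee = parent.layer("Linear", fl.Linear)
adapter = DummyLinearAdapter(adaptee).inject(parent)
assert adapter.parent == parent
```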
diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index 34f6325..acd2ddc 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -3,7 +3,7 @@ import re
 import sys
 import traceback
 from collections import defaultdict
-from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
+from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
 
 import torch
 from torch import Tensor, cat, device as Device, dtype as DType
@@ -362,6 +362,48 @@ class Chain(ContextModule):
             recurse=recurse,
         )
 
+    def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:
+        """Access a layer of the Chain given its type.
+
+        Example:
+            ```py
+            # same as my_chain["Linear_2"], asserts it is a Linear
+            my_chain.layer("Linear_2", fl.Linear)
+
+
+            # same as my_chain[3], asserts it is a Linear
+            my_chain.layer(3, fl.Linear)
+
+            # probably won't work
+            my_chain.layer("Conv2d", fl.Linear)
+
+
+            # same as my_chain["foo"][42]["bar"],
+            # assuming bar is a MyType and all parents are Chains
+            my_chain.layer(("foo", 42, "bar"), fl.MyType)
+            ```
+
+        Args:
+            key: The key or path of the layer.
+            layer_type: The type of the layer.
+
+        Returns:
+            The layer.
+
+        Raises:
+            AssertionError: If the layer doesn't exist or the type is invalid.
+        """
+        if isinstance(key, (str, int)):
+            r = self[key]
+            assert isinstance(r, layer_type), f"layer {key} is {type(r)}, not {layer_type}"
+            return r
+        if len(key) == 0:
+            assert isinstance(self, layer_type), f"layer is {type(self)}, not {layer_type}"
+            return self
+        if len(key) == 1:
+            return self.layer(key[0], layer_type)
+        return self.layer(key[0], Chain).layer(key[1:], layer_type)
+
     def layers(
         self,
         layer_type: type[T],
@@ -576,6 +618,7 @@ class Chain(ContextModule):
         require extra GPU memory since the weights are in the leaves and hence not copied.
         """
         if hasattr(self, "_pre_structural_copy"):
+            assert callable(self._pre_structural_copy)
             self._pre_structural_copy()
 
         modules = [structural_copy(m) for m in self]
@@ -586,6 +629,7 @@ class Chain(ContextModule):
             clone.append(module=module)
 
         if hasattr(clone, "_post_structural_copy"):
+            assert callable(clone._post_structural_copy)
             clone._post_structural_copy(self)
 
         return clone
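`Chain.layer` asserts both existence and type, so call sites get a precisely typed handle instead of the `object` now returned by plain attribute access. A small usage sketch grounded in the docstring above and `test_chain_layer` below:

```py
import refiners.fluxion.layers as fl

chain = fl.Chain(
    fl.Sum(fl.Chain(), fl.Chain()),
)

sum_ = chain.layer(0, fl.Sum)              # by index, asserted to be a Sum
assert chain.layer("Sum", fl.Sum) == sum_  # by auto-generated name
assert chain.layer(("Sum", "Chain_2"), fl.Chain) == chain.layer((0, 1), fl.Chain)  # by path

# the empty path designates the chain itself; a wrong type raises AssertionError
assert chain.layer((), fl.Chain) == chain
```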
diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py
index 3654a1c..3899f9a 100644
--- a/src/refiners/fluxion/layers/module.py
+++ b/src/refiners/fluxion/layers/module.py
@@ -1,11 +1,12 @@
+import contextlib
 import sys
 from collections import defaultdict
 from inspect import Parameter, signature
 from pathlib import Path
 from types import ModuleType
-from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast
+from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Iterator, Sequence, TypedDict, TypeVar, cast
 
-from torch import device as Device, dtype as DType
+from torch import Tensor, device as Device, dtype as DType
 from torch.nn.modules.module import Module as TorchModule
 
 from refiners.fluxion.context import Context, ContextProvider
@@ -29,7 +30,13 @@ class Module(TorchModule):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, *kwargs)  # type: ignore[reportUnknownMemberType]
 
-    def __getattr__(self, name: str) -> Any:
+    def __getattr__(self, name: str) -> object:
+        # Note: PyTorch returns `Any` as of 2.2 and is considering
+        # going back to `Tensor | Module`, but in truth `__getattr__`
+        # cannot be typed correctly. Since PyTorch assumes its users
+        # write highly dynamic code, it returns Python's top type `Any`.
+        # In Refiners, static type checking is a core design value,
+        # hence we return `object` instead.
         return super().__getattr__(name=name)
 
     def __setattr__(self, name: str, value: Any) -> None:
@@ -220,10 +227,19 @@ class ContextModule(Module):
 
         return super().get_path(parent=parent or self.parent, top=top)
 
+    @contextlib.contextmanager
+    def no_parent_refresh(self) -> Iterator[None]:
+        _old_can_refresh_parent = self._can_refresh_parent
+        self._can_refresh_parent = False
+        yield
+        self._can_refresh_parent = _old_can_refresh_parent
+
 
 class WeightedModule(Module):
     """A module with a weight (Tensor) attribute."""
 
+    weight: Tensor
+
     @property
     def device(self) -> Device:
         """Return the device of the module's weight."""
diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
index a53e517..f17c783 100644
--- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
@@ -157,7 +157,7 @@ class Decoder(Chain):
         )
         resnet_layers[0].insert(1, attention_layer)
         for _, layer in zip(range(3), resnet_layers[1:]):
-            channels: int = layer[-1].out_channels
+            channels: int = layer.layer(-1, Resnet).out_channels
             layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
         super().__init__(
             Conv2d(
diff --git a/src/refiners/foundationals/latent_diffusion/freeu.py b/src/refiners/foundationals/latent_diffusion/freeu.py
index 3e2580f..f05b874 100644
--- a/src/refiners/foundationals/latent_diffusion/freeu.py
+++ b/src/refiners/foundationals/latent_diffusion/freeu.py
@@ -73,7 +73,7 @@ class FreeUResidualConcatenator(fl.Concatenate):
 class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
     def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
         assert len(backbone_scales) == len(skip_scales)
-        assert len(backbone_scales) <= len(target.UpBlocks)
+        assert len(backbone_scales) <= len(target.layer("UpBlocks", fl.Chain))
         self.backbone_scales = backbone_scales
         self.skip_scales = skip_scales
         with self.setup_adapter(target):
@@ -88,7 +88,7 @@ class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
 
     def eject(self) -> None:
         for n in range(len(self.backbone_scales)):
-            block = self.target.UpBlocks[n]
+            block = self.target.layer(("UpBlocks", n), fl.Chain)
             concat = block.ensure_find(FreeUResidualConcatenator)
             block.replace(concat, ResidualConcatenator(-n - 2))
         super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 1c35eb9..e90337c 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -307,11 +307,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
 
     @property
     def image_key_projection(self) -> fl.Linear:
-        return self.image_cross_attention.Distribute[1].Linear
+        return self.image_cross_attention.layer(("Distribute", 1, "Linear"), fl.Linear)
 
     @property
     def image_value_projection(self) -> fl.Linear:
-        return self.image_cross_attention.Distribute[2].Linear
+        return self.image_cross_attention.layer(("Distribute", 2, "Linear"), fl.Linear)
 
     @property
     def scale(self) -> float:
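The two properties above show the intended pattern: wrap a structural path in a typed property so call sites never index into the chain manually. A hedged sketch of the same pattern on a hypothetical chain (`MyAttention` and its layout are illustrative, not part of the library):

```py
import refiners.fluxion.layers as fl


class MyAttention(fl.Chain):
    """Hypothetical chain mirroring the Distribute layout accessed above."""

    def __init__(self) -> None:
        super().__init__(
            fl.Distribute(
                fl.Linear(8, 8),  # query projection
                fl.Linear(8, 8),  # key projection
                fl.Linear(8, 8),  # value projection
            ),
        )

    @property
    def key_projection(self) -> fl.Linear:
        # a typed path lookup, like image_key_projection above
        return self.layer(("Distribute", 1), fl.Linear)
```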
diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py
index b4a2f82..697debe 100644
--- a/src/refiners/foundationals/latent_diffusion/model.py
+++ b/src/refiners/foundationals/latent_diffusion/model.py
@@ -9,17 +9,15 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from refiners.foundationals.latent_diffusion.solvers import Solver
 
-T = TypeVar("T", bound="fl.Module")
-
 TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
 
 
 class LatentDiffusionModel(fl.Module, ABC):
     def __init__(
         self,
-        unet: fl.Module,
+        unet: fl.Chain,
         lda: LatentDiffusionAutoencoder,
-        clip_text_encoder: fl.Module,
+        clip_text_encoder: fl.Chain,
         solver: Solver,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
index 1b8b9ff..8729bd4 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
@@ -1,5 +1,3 @@
-from typing import Iterable, cast
-
 from torch import Tensor, device as Device, dtype as DType
 
 from refiners.fluxion.adapters.adapter import Adapter
@@ -93,28 +91,31 @@ class Controlnet(Passthrough):
 
         # We run the condition encoder at each step. Caching the result
         # is not worth it as subsequent runs take virtually no time (FG-374).
-        self.DownBlocks[0].append(
+
+        self.layer(("DownBlocks", 0), Chain).append(
             Residual(
                 UseContext("controlnet", f"condition_{name}"),
                 ConditionEncoder(device=device, dtype=dtype),
             ),
         )
         for residual_block in self.layers(ResidualBlock):
-            chain = residual_block.Chain
+            chain = residual_block.layer("Chain", Chain)
             RangeAdapter2d(
-                target=chain.Conv2d_1,
+                target=chain.layer("Conv2d_1", Conv2d),
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key=f"timestep_embedding_{name}",
                 device=device,
                 dtype=dtype,
            ).inject(chain)
-        for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
-            assert hasattr(block[0], "out_channels"), (
+        for n, block in enumerate(self.layer("DownBlocks", DownBlocks)):
+            assert isinstance(block, Chain)
+            b0 = block[0]
+            assert hasattr(b0, "out_channels"), (
                 "The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
-                f" {block[0]} does not."
+                f" {b0} does not."
             )
-            out_channels: int = block[0].out_channels
+            assert isinstance(out_channels := b0.out_channels, int)
             block.append(
                 Passthrough(
                     Conv2d(
@@ -123,7 +124,7 @@ class Controlnet(Passthrough):
                     Lambda(self._store_nth_residual(n)),
                 )
             )
-        self.MiddleBlock.append(
+        self.layer("MiddleBlock", MiddleBlock).append(
             Passthrough(
                 Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
                 Lambda(self._store_nth_residual(12)),
@@ -166,6 +167,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
         self.target.remove(self._controlnet[0])
         super().eject()
 
+    @property
+    def controlnet(self) -> Controlnet:
+        return self._controlnet[0]
+
     def init_context(self) -> Contexts:
         return {"controlnet": {f"condition_{self.name}": None}}
 
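The new `controlnet` property gives callers a typed handle on the injected `Controlnet`, which is what the conversion script at the top of this diff now uses instead of `unet.Controlnet` (import paths as in that script's module headers):

```py
import torch

from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter

unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
controlnet = adapter.controlnet  # typed as Controlnet

condition = torch.randn(1, 3, 512, 512)
adapter.set_controlnet_condition(condition=condition)
```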
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
index c70b7ce..b6b2a97 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
@@ -25,7 +25,7 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
 
     def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
         for n, feat in zip(self.residual_indices, self._features, strict=True):
-            block = self.target.DownBlocks[n]
+            block = self.target.layer(("DownBlocks", n), fl.Chain)
             for t2i_layer in block.layers(layer_type=T2IFeatures):
                 assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
             block.insert_before_type(ResidualAccumulator, feat)
@@ -33,5 +33,5 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
 
     def eject(self: "SD1T2IAdapter") -> None:
         for n, feat in zip(self.residual_indices, self._features, strict=True):
-            self.target.DownBlocks[n].remove(feat)
+            self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
         super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
index 8f102a4..25e7502 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
@@ -6,6 +6,11 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.context import Contexts
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
 from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder
+from refiners.foundationals.latent_diffusion.unet import (
+    ResidualAccumulator,
+    ResidualBlock,
+    ResidualConcatenator,
+)
 
 
 class TimestepEncoder(fl.Passthrough):
@@ -22,54 +27,6 @@ class TimestepEncoder(fl.Passthrough):
         )
 
 
-class ResidualBlock(fl.Sum):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        num_groups: int = 32,
-        eps: float = 1e-5,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
-        if in_channels % num_groups != 0 or out_channels % num_groups != 0:
-            raise ValueError("Number of input and output channels must be divisible by num_groups.")
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.num_groups = num_groups
-        self.eps = eps
-        shortcut = (
-            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
-            if in_channels != out_channels
-            else fl.Identity()
-        )
-        super().__init__(
-            fl.Chain(
-                fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
-                fl.SiLU(),
-                fl.Conv2d(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    kernel_size=3,
-                    padding=1,
-                    device=device,
-                    dtype=dtype,
-                ),
-                fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
-                fl.SiLU(),
-                fl.Conv2d(
-                    in_channels=out_channels,
-                    out_channels=out_channels,
-                    kernel_size=3,
-                    padding=1,
-                    device=device,
-                    dtype=dtype,
-                ),
-            ),
-            shortcut,
-        )
-
-
 class CLIPLCrossAttention(CrossAttentionBlock2d):
     def __init__(
         self,
@@ -205,34 +162,6 @@ class MiddleBlock(fl.Chain):
         )
 
 
-class ResidualAccumulator(fl.Passthrough):
-    def __init__(self, n: int) -> None:
-        self.n = n
-
-        super().__init__(
-            fl.Residual(
-                fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
-            ),
-            fl.SetContext(context="unet", key="residuals", callback=self.update),
-        )
-
-    def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
-        residuals[self.n] = x
-
-
-class ResidualConcatenator(fl.Chain):
-    def __init__(self, n: int) -> None:
-        self.n = n
-
-        super().__init__(
-            fl.Concatenate(
-                fl.Identity(),
-                fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
-                dim=1,
-            ),
-        )
-
-
 class SD1UNet(fl.Chain):
     """Stable Diffusion 1.5 U-Net.
 
@@ -275,9 +204,9 @@ class SD1UNet(fl.Chain):
             ),
         )
         for residual_block in self.layers(ResidualBlock):
-            chain = residual_block.Chain
+            chain = residual_block.layer("Chain", fl.Chain)
             RangeAdapter2d(
-                target=chain.Conv2d_1,
+                target=chain.layer("Conv2d_1", fl.Conv2d),
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key="timestep_embedding",
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
index 3e6d8a3..7e709f1 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
@@ -2,7 +2,7 @@ from torch import Tensor
 
 import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, SDXLUNet
 from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures
 
 
@@ -31,18 +31,19 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
         # Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
         for n, feat in zip(self.residual_indices, self._features, strict=False):
-            block = self.target.DownBlocks[n]
+            block = self.target.layer(("DownBlocks", n), fl.Chain)
             sanity_check_t2i(block)
             block.insert_before_type(ResidualAccumulator, feat)
         # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
-        sanity_check_t2i(self.target.MiddleBlock)
-        self.target.MiddleBlock.append(self._features[-1])
+        mid_block = self.target.layer("MiddleBlock", MiddleBlock)
+        sanity_check_t2i(mid_block)
+        mid_block.append(self._features[-1])
         return super().inject(parent)
 
     def eject(self: "SDXLT2IAdapter") -> None:
         # See `inject` re: `strict=False`
         for n, feat in zip(self.residual_indices, self._features, strict=False):
-            self.target.DownBlocks[n].remove(feat)
-        self.target.MiddleBlock.remove(self._features[-1])
+            self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
+        self.target.layer("MiddleBlock", MiddleBlock).remove(self._features[-1])
         super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
index 72a6245..cd45e5a 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
@@ -72,7 +72,9 @@ class DoubleTextEncoder(fl.Chain):
             fl.Parallel(text_encoder_l[:-2], text_encoder_g),
             fl.Lambda(func=self.concatenate_embeddings),
         )
-        TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
+        TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
+            parent=self.layer("Parallel", fl.Parallel)
+        )
 
     def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
         return super().__call__(text)
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
index 6514442..9663fa6 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
@@ -10,7 +10,7 @@ from refiners.foundationals.latent_diffusion.range_adapter import (
     RangeEncoder,
     compute_sinusoidal_embedding,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
+from refiners.foundationals.latent_diffusion.unet import (
     ResidualAccumulator,
     ResidualBlock,
     ResidualConcatenator,
@@ -267,9 +267,9 @@ class SDXLUNet(fl.Chain):
             OutputBlock(device=device, dtype=dtype),
         )
         for residual_block in self.layers(ResidualBlock):
-            chain = residual_block.Chain
+            chain = residual_block.layer("Chain", fl.Chain)
             RangeAdapter2d(
-                target=chain.Conv2d_1,
+                target=chain.layer("Conv2d_1", fl.Conv2d),
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key="timestep_embedding",
diff --git a/src/refiners/foundationals/latent_diffusion/unet.py b/src/refiners/foundationals/latent_diffusion/unet.py
new file mode 100644
index 0000000..fc7198d
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/unet.py
@@ -0,0 +1,79 @@
+from torch import Tensor, device as Device, dtype as DType
+
+import refiners.fluxion.layers as fl
+
+
+class ResidualBlock(fl.Sum):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_groups: int = 32,
+        eps: float = 1e-5,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        if in_channels % num_groups != 0 or out_channels % num_groups != 0:
+            raise ValueError("Number of input and output channels must be divisible by num_groups.")
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_groups = num_groups
+        self.eps = eps
+        shortcut = (
+            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
+            if in_channels != out_channels
+            else fl.Identity()
+        )
+        super().__init__(
+            fl.Chain(
+                fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
+                fl.SiLU(),
+                fl.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    device=device,
+                    dtype=dtype,
+                ),
+                fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
+                fl.SiLU(),
+                fl.Conv2d(
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    device=device,
+                    dtype=dtype,
+                ),
+            ),
+            shortcut,
+        )
+
+
+class ResidualAccumulator(fl.Passthrough):
+    def __init__(self, n: int) -> None:
+        self.n = n
+
+        super().__init__(
+            fl.Residual(
+                fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
+            ),
+            fl.SetContext(context="unet", key="residuals", callback=self.update),
+        )
+
+    def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
+        residuals[self.n] = x
+
+
+class ResidualConcatenator(fl.Chain):
+    def __init__(self, n: int) -> None:
+        self.n = n
+
+        super().__init__(
+            fl.Concatenate(
+                fl.Identity(),
+                fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
+                dim=1,
+            ),
+        )
diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py
index edcd116..f062ecb 100644
--- a/src/refiners/foundationals/segment_anything/prompt_encoder.py
+++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py
@@ -188,6 +188,7 @@ class MaskEncoder(fl.Chain):
     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1
     ) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
-        return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
+        no_mask_embedding = cast(Tensor, self.no_mask_embedding)
+        return no_mask_embedding.reshape(1, -1, 1, 1).expand(
             batch_size, -1, image_embedding_size[0], image_embedding_size[1]
         )
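The `cast` above is needed because `no_mask_embedding` is resolved through `Module.__getattr__`, which now advertises `object` (see `module.py` earlier in this diff). The same idiom applies to any buffer or parameter fetched dynamically; a minimal sketch with a hypothetical module:

```py
from typing import cast

import torch
from torch import Tensor

import refiners.fluxion.layers as fl


class WithBuffer(fl.Chain):
    """Hypothetical chain holding a registered buffer."""

    def __init__(self) -> None:
        super().__init__(fl.Identity())
        self.register_buffer("scale", torch.ones(1))


m = WithBuffer()
scale = cast(Tensor, m.scale)  # attribute access yields `object`, narrow it explicitly
assert scale.shape == (1,)
```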
diff --git a/tests/adapters/test_adapter.py b/tests/adapters/test_adapter.py
index 8049f97..38b1bd0 100644
--- a/tests/adapters/test_adapter.py
+++ b/tests/adapters/test_adapter.py
@@ -22,8 +22,8 @@ def chain() -> Chain:
 
 
 def test_weighted_module_adapter_insertion(chain: Chain):
-    parent = chain.Chain
-    adaptee = parent.Linear
+    parent = chain.layer("Chain", Chain)
+    adaptee = parent.layer("Linear", Linear)
 
     adapter = DummyLinearAdapter(adaptee).inject(parent)
 
@@ -39,7 +39,7 @@ def test_weighted_module_adapter_insertion(chain: Chain):
 
 def test_chain_adapter_insertion(chain: Chain):
     parent = chain
-    adaptee = parent.Chain
+    adaptee = parent.layer("Chain", Chain)
 
     adapter = DummyChainAdapter(adaptee)
     assert adaptee.parent == parent
@@ -58,20 +58,20 @@ def test_chain_adapter_insertion(chain: Chain):
 
 
 def test_weighted_module_adapter_structural_copy(chain: Chain):
-    parent = chain.Chain
-    adaptee = parent.Linear
+    parent = chain.layer("Chain", Chain)
+    adaptee = parent.layer("Linear", Linear)
 
     DummyLinearAdapter(adaptee).inject(parent)
 
     clone = chain.structural_copy()
-    cloned_adapter = clone.Chain.DummyLinearAdapter
+    cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter)
     assert cloned_adapter.parent == clone.Chain
     assert cloned_adapter.target == adaptee
 
 
 def test_chain_adapter_structural_copy(chain: Chain):
     # Chain adapters cannot be copied by default.
-    adapter = DummyChainAdapter(chain.Chain).inject()
+    adapter = DummyChainAdapter(chain.layer("Chain", Chain)).inject()
 
     with pytest.raises(RuntimeError):
         chain.structural_copy()
diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py
index d298aa1..8666982 100644
--- a/tests/adapters/test_lora.py
+++ b/tests/adapters/test_lora.py
@@ -27,6 +27,7 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
     assert conv_lora.scale == 1.0
     assert conv_lora.in_channels == conv_lora.down.in_channels == 16
     assert conv_lora.out_channels == conv_lora.up.out_channels == 8
+    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
     # padding is set so the spatial dimensions are preserved
     assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@@ -39,10 +40,12 @@ def test_scale_setter(lora: LinearLora) -> None:
 
 
 def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
+    assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
     new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
     x = torch.randn(1, 320)
     assert torch.allclose(lora(x), new_lora(x))
 
+    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
     new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
     x = torch.randn(1, 16, 64, 64)
     assert torch.allclose(conv_lora(x), new_conv_lora(x))
diff --git a/tests/adapters/test_range_adapter.py b/tests/adapters/test_range_adapter.py
index 90e3bdc..f42cb99 100644
--- a/tests/adapters/test_range_adapter.py
+++ b/tests/adapters/test_range_adapter.py
@@ -15,8 +15,9 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device):  # FG-
     dtype = torch.float64
     chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
 
-    adaptee = chain.RangeEncoder.Linear_1
-    adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder)
+    range_encoder = chain.layer("RangeEncoder", RangeEncoder)
+    adaptee = range_encoder.layer("Linear_1", Linear)
+    adapter = DummyLinearAdapter(adaptee).inject(range_encoder)
 
     assert adapter.parent == chain.RangeEncoder
 
diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py
index 5e87812..03e861f 100644
--- a/tests/fluxion/layers/test_chain.py
+++ b/tests/fluxion/layers/test_chain.py
@@ -30,8 +30,9 @@ def test_chain_getitem_accessor() -> None:
 
 def test_chain_find_parent():
     chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
+    subchain = chain.layer("Chain", fl.Chain)
 
-    assert chain.find_parent(chain.Chain.Linear) == chain.Chain
+    assert chain.find_parent(subchain.layer("Linear", fl.Linear)) == subchain
     assert chain.find_parent(fl.Linear(1, 1)) is None
 
@@ -64,17 +65,20 @@ def test_chain_walk() -> None:
         fl.Chain(),
     )
 
-    assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)]
+    sum_ = chain.layer("Sum", fl.Sum)
+    sum_chain = sum_.layer("Chain", fl.Chain)
+
+    assert list(chain.walk()) == [(sum_, chain), (chain.Chain, chain)]
 
     assert list(chain.walk(fl.Linear)) == [
-        (chain.Sum.Chain.Linear, chain.Sum.Chain),
-        (chain.Sum.Linear, chain.Sum),
+        (sum_chain.Linear, sum_chain),
+        (sum_.Linear, sum_),
     ]
 
     assert list(chain.walk(recurse=True)) == [
-        (chain.Sum, chain),
-        (chain.Sum.Chain, chain.Sum),
-        (chain.Sum.Chain.Linear, chain.Sum.Chain),
-        (chain.Sum.Linear, chain.Sum),
+        (sum_, chain),
+        (sum_chain, sum_),
+        (sum_chain.Linear, sum_chain),
+        (sum_.Linear, sum_),
         (chain.Chain, chain),
     ]
 
@@ -98,6 +102,29 @@ def test_chain_walk_stop_iteration() -> None:
     assert len(list(chain.walk(predicate))) == 1
 
 
+def test_chain_layer() -> None:
+    chain = fl.Chain(
+        fl.Sum(fl.Chain(), fl.Chain()),
+    )
+
+    sum_ = chain.layer(0, fl.Sum)
+    assert chain.layer("Sum", fl.Sum) == sum_
+    assert chain.layer("Sum", fl.Chain) == sum_
+
+    chain_2 = chain.layer((0, 1), fl.Chain)
+    assert chain.layer((0, 1)) == chain_2
+    assert chain.layer((0, "Chain_2"), fl.Chain) == chain_2
+    assert chain.layer(("Sum", "Chain_2"), fl.Chain) == chain_2
+
+    assert chain.layer((), fl.Chain) == chain
+
+    with pytest.raises(AssertionError):
+        chain.layer((0, 1), fl.Sum)
+
+    with pytest.raises(AssertionError):
+        chain.layer((), fl.Sum)
+
+
 def test_chain_layers() -> None:
     chain = fl.Chain(
         fl.Chain(fl.Chain(fl.Chain())),
@@ -214,11 +241,12 @@ def test_chain_replace() -> None:
         fl.Linear(1, 1),
         fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
     )
+    subchain = chain.layer("Chain", fl.Chain)
 
-    assert isinstance(chain.Chain[1], fl.Linear)
-    chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1))
+    assert isinstance(subchain[1], fl.Linear)
+    subchain.replace(subchain[1], fl.Conv2d(1, 1, 1))
     assert len(chain) == 3
-    assert isinstance(chain.Chain[1], fl.Conv2d)
+    assert isinstance(subchain[1], fl.Conv2d)
 
 
 def test_chain_structural_copy() -> None:
@@ -236,15 +264,18 @@ def test_chain_structural_copy() -> None:
 
     m2 = m.structural_copy()
 
-    assert m.Linear == m2.Linear
-    assert m.Sum.Linear_1 == m2.Sum.Linear_1
-    assert m.Sum.Linear_2 == m2.Sum.Linear_2
+    m_sum = m.layer("Sum", fl.Sum)
+    m2_sum = m2.layer("Sum", fl.Sum)
 
-    assert m.Sum != m2.Sum
+    assert m.Linear == m2.Linear
+    assert m_sum.Linear_1 == m2_sum.Linear_1
+    assert m_sum.Linear_2 == m2_sum.Linear_2
+
+    assert m_sum != m2_sum
     assert m != m2
 
-    assert m.Sum.parent == m
-    assert m2.Sum.parent == m2
+    assert m_sum.parent == m
+    assert m2_sum.parent == m2
 
     y2 = m2(x)
     assert y2.shape == (7, 12)
diff --git a/tests/fluxion/test_module.py b/tests/fluxion/test_module.py
index bf87650..fc82398 100644
--- a/tests/fluxion/test_module.py
+++ b/tests/fluxion/test_module.py
@@ -10,9 +10,12 @@ def test_module_get_path() -> None:
         fl.Sum(),
     )
 
-    assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
-    assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
-    assert chain.Sum_1.get_path() == "Chain.Sum_1"
+    sum_1 = chain.layer("Sum_1", fl.Sum)
+    linear_2 = sum_1.layer("Linear_2", fl.Linear)
+
+    assert linear_2.get_path(parent=sum_1) == "Chain.Sum_1.Linear_2"
+    assert linear_2.get_path(parent=sum_1, top=sum_1) == "Sum.Linear_2"
+    assert sum_1.get_path() == "Chain.Sum_1"
 
 
 def test_module_basic_attributes() -> None:
diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py
index 9ab1e7e..debeee5 100644
--- a/tests/foundationals/clip/test_concepts.py
+++ b/tests/foundationals/clip/test_concepts.py
@@ -85,13 +85,14 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
 
 
 def test_tokenizer_with_special_character():
-    clip_tokenizer = fl.Chain(CLIPTokenizer())
-    token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer)
-    new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42
+    clip_tokenizer_chain = fl.Chain(CLIPTokenizer())
+    original_clip_tokenizer = clip_tokenizer_chain.layer("CLIPTokenizer", CLIPTokenizer)
+    token_extender = TokenExtender(original_clip_tokenizer)
+    new_token_id = max(original_clip_tokenizer.token_to_id_mapping.values()) + 42
     token_extender.add_token("*", new_token_id)
-    token_extender.inject(clip_tokenizer)
+    token_extender.inject(clip_tokenizer_chain)
 
-    adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)
+    adapted_clip_tokenizer = clip_tokenizer_chain.ensure_find(CLIPTokenizer)
     assert torch.allclose(
         adapted_clip_tokenizer.encode("*"),
diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py
index 3e4553e..487f369 100644
--- a/tests/foundationals/latent_diffusion/test_freeu.py
+++ b/tests/foundationals/latent_diffusion/test_freeu.py
@@ -4,6 +4,7 @@ import pytest
 import torch
 
 from refiners.fluxion import manual_seed
+from refiners.fluxion.layers import Chain
 from refiners.fluxion.utils import no_grad
 from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
 from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
@@ -33,7 +34,7 @@ def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
 
 
 def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
-    num_blocks = len(unet.UpBlocks)
+    num_blocks = len(unet.layer("UpBlocks", Chain))
     with pytest.raises(AssertionError):
         SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
 
diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py
index 1a62217..e8afa41 100644
--- a/tests/foundationals/segment_anything/test_sam.py
+++ b/tests/foundationals/segment_anything/test_sam.py
@@ -16,10 +16,11 @@ from tests.foundationals.segment_anything.utils import (
 )
 from torch import Tensor
 
+import refiners.fluxion.layers as fl
 from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
-from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
+from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
 
@@ -104,17 +105,22 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
     manual_seed(seed=0)
     x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
 
-    attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)  # type: ignore
+    attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)
 
     refiners_attention = FusedSelfAttention(
         embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
     )
-    refiners_attention.Linear_1.weight = attention.qkv.weight  # type: ignore
-    refiners_attention.Linear_1.bias = attention.qkv.bias  # type: ignore
-    refiners_attention.Linear_2.weight = attention.proj.weight  # type: ignore
-    refiners_attention.Linear_2.bias = attention.proj.bias  # type: ignore
-    refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
-    refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h
+
+    rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention)
+    linear_1 = refiners_attention.layer("Linear_1", fl.Linear)
+    linear_2 = refiners_attention.layer("Linear_2", fl.Linear)
+
+    linear_1.weight = attention.qkv.weight
+    linear_1.bias = attention.qkv.bias
+    linear_2.weight = attention.proj.weight
+    linear_2.bias = attention.proj.bias
+    rpa.horizontal_embedding = attention.rel_pos_w
+    rpa.vertical_embedding = attention.rel_pos_h
 
     y_1 = attention(x)
     assert y_1.shape == x.shape
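Taken together with the `weight: Tensor` annotation on `WeightedModule`, the typed lookups above let weight surgery like this test type-check without `# type: ignore`. A condensed sketch of the idiom (the 16-unit layers are arbitrary):

```py
import torch
from torch import nn

import refiners.fluxion.layers as fl

source = nn.Linear(16, 16)
target = fl.Chain(fl.Linear(16, 16))

linear = target.layer("Linear", fl.Linear)  # precisely typed as fl.Linear
linear.weight = source.weight  # Parameter assignment, no cast needed
linear.bias = source.bias

x = torch.randn(2, 16)
assert torch.allclose(linear(x), source(x))
```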