make __getattr__ on Module return object, not Any

PyTorch chose to make it Any because they expect its users' code
to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321

It is not the case for us, in Refiners having untyped code
goes contrary to one of our core principles.

Note that there is currently an open PR in PyTorch to
return `Module | Tensor`, but in practice this is not always
correct either: https://github.com/pytorch/pytorch/pull/115074

I also moved Residuals-related code from SD1 to latent_diffusion
because SDXL should not depend on SD1.
This commit is contained in:
Pierre Chapuis 2024-02-05 17:10:05 +01:00
parent ec401133f1
commit 471ef91d1c
26 changed files with 288 additions and 168 deletions

View file

@ -29,7 +29,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
) )
unet = SD1UNet(in_channels=4) unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(unet, name="mycn").inject() adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
controlnet = unet.Controlnet controlnet = adapter.controlnet
condition = torch.randn(1, 3, 512, 512) condition = torch.randn(1, 3, 512, 512)
adapter.set_controlnet_condition(condition=condition) adapter.set_controlnet_condition(condition=condition)

View file

@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
target.set_timestep(timestep=timestep) target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings) target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
added_cond_kwargs = {} added_cond_kwargs = {}
if source_has_time_ids: if isinstance(target, SDXLUNet):
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)} added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"]) target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"]) target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])

View file

@ -11,7 +11,7 @@ from torch import Tensor
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
from refiners.foundationals.segment_anything.image_encoder import SAMViTH from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@ -136,7 +136,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
) )
embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight) positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
converted_source.update(rel_items) converted_source.update(rel_items)

View file

@ -43,14 +43,12 @@ class Adapter(Generic[T]):
), "Call the Chain constructor in the setup_adapter context." ), "Call the Chain constructor in the setup_adapter context."
self._target = [target] self._target = [target]
if not isinstance(self.target, fl.ContextModule): if isinstance(target, fl.ContextModule):
assert isinstance(target, fl.ContextModule)
with target.no_parent_refresh():
yield
else:
yield yield
return
_old_can_refresh_parent = target._can_refresh_parent
target._can_refresh_parent = False
yield
target._can_refresh_parent = _old_can_refresh_parent
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter: def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
"""Inject the adapter. """Inject the adapter.

View file

@ -3,7 +3,7 @@ import re
import sys import sys
import traceback import traceback
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
import torch import torch
from torch import Tensor, cat, device as Device, dtype as DType from torch import Tensor, cat, device as Device, dtype as DType
@ -362,6 +362,48 @@ class Chain(ContextModule):
recurse=recurse, recurse=recurse,
) )
def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:
"""Access a layer of the Chain given its type.
Example:
```py
# same as my_chain["Linear_2"], asserts it is a Linear
my_chain.layer("Linear_2", fl.Linear)
# same as my_chain[3], asserts it is a Linear
my_chain.layer(3, fl.Linear)
# probably won't work
my_chain.layer("Conv2d", fl.Linear)
# same as my_chain["foo"][42]["bar"],
# assuming bar is a MyType and all parents are Chains
my_chain.layer(("foo", 42, "bar"), fl.MyType)
```
Args:
key: The key or path of the layer.
layer_type: The type of the layer.
Yields:
The layer.
Raises:
AssertionError: If the layer doesn't exist or the type is invalid.
"""
if isinstance(key, (str, int)):
r = self[key]
assert isinstance(r, layer_type), f"layer {key} is {type(r)}, not {layer_type}"
return r
if len(key) == 0:
assert isinstance(self, layer_type), f"layer is {type(self)}, not {layer_type}"
return self
if len(key) == 1:
return self.layer(key[0], layer_type)
return self.layer(key[0], Chain).layer(key[1:], layer_type)
def layers( def layers(
self, self,
layer_type: type[T], layer_type: type[T],
@ -576,6 +618,7 @@ class Chain(ContextModule):
require extra GPU memory since the weights are in the leaves and hence not copied. require extra GPU memory since the weights are in the leaves and hence not copied.
""" """
if hasattr(self, "_pre_structural_copy"): if hasattr(self, "_pre_structural_copy"):
assert callable(self._pre_structural_copy)
self._pre_structural_copy() self._pre_structural_copy()
modules = [structural_copy(m) for m in self] modules = [structural_copy(m) for m in self]
@ -586,6 +629,7 @@ class Chain(ContextModule):
clone.append(module=module) clone.append(module=module)
if hasattr(clone, "_post_structural_copy"): if hasattr(clone, "_post_structural_copy"):
assert callable(clone._post_structural_copy)
clone._post_structural_copy(self) clone._post_structural_copy(self)
return clone return clone

View file

@ -1,11 +1,12 @@
import contextlib
import sys import sys
from collections import defaultdict from collections import defaultdict
from inspect import Parameter, signature from inspect import Parameter, signature
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Iterator, Sequence, TypedDict, TypeVar, cast
from torch import device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from torch.nn.modules.module import Module as TorchModule from torch.nn.modules.module import Module as TorchModule
from refiners.fluxion.context import Context, ContextProvider from refiners.fluxion.context import Context, ContextProvider
@ -29,7 +30,13 @@ class Module(TorchModule):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, *kwargs) # type: ignore[reportUnknownMemberType] super().__init__(*args, *kwargs) # type: ignore[reportUnknownMemberType]
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> object:
# Note: PyTorch returns `Any` as of 2.2 and is considering
# going back to `Tensor | Module`, but the truth is it is
# impossible to type `__getattr__` correctly.
# Because PyTorch assumes its users write highly dynamic code,
# it returns Python's top type `Any`. In Refiners, static type
# checking is a core design value, hence we return `object` instead.
return super().__getattr__(name=name) return super().__getattr__(name=name)
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
@ -220,10 +227,19 @@ class ContextModule(Module):
return super().get_path(parent=parent or self.parent, top=top) return super().get_path(parent=parent or self.parent, top=top)
@contextlib.contextmanager
def no_parent_refresh(self) -> Iterator[None]:
_old_can_refresh_parent = self._can_refresh_parent
self._can_refresh_parent = False
yield
self._can_refresh_parent = _old_can_refresh_parent
class WeightedModule(Module): class WeightedModule(Module):
"""A module with a weight (Tensor) attribute.""" """A module with a weight (Tensor) attribute."""
weight: Tensor
@property @property
def device(self) -> Device: def device(self) -> Device:
"""Return the device of the module's weight.""" """Return the device of the module's weight."""

View file

@ -157,7 +157,7 @@ class Decoder(Chain):
) )
resnet_layers[0].insert(1, attention_layer) resnet_layers[0].insert(1, attention_layer)
for _, layer in zip(range(3), resnet_layers[1:]): for _, layer in zip(range(3), resnet_layers[1:]):
channels: int = layer[-1].out_channels channels: int = layer.layer(-1, Resnet).out_channels
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype)) layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
super().__init__( super().__init__(
Conv2d( Conv2d(

View file

@ -73,7 +73,7 @@ class FreeUResidualConcatenator(fl.Concatenate):
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]): class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None: def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
assert len(backbone_scales) == len(skip_scales) assert len(backbone_scales) == len(skip_scales)
assert len(backbone_scales) <= len(target.UpBlocks) assert len(backbone_scales) <= len(target.layer("UpBlocks", fl.Chain))
self.backbone_scales = backbone_scales self.backbone_scales = backbone_scales
self.skip_scales = skip_scales self.skip_scales = skip_scales
with self.setup_adapter(target): with self.setup_adapter(target):
@ -88,7 +88,7 @@ class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
def eject(self) -> None: def eject(self) -> None:
for n in range(len(self.backbone_scales)): for n in range(len(self.backbone_scales)):
block = self.target.UpBlocks[n] block = self.target.layer(("UpBlocks", n), fl.Chain)
concat = block.ensure_find(FreeUResidualConcatenator) concat = block.ensure_find(FreeUResidualConcatenator)
block.replace(concat, ResidualConcatenator(-n - 2)) block.replace(concat, ResidualConcatenator(-n - 2))
super().eject() super().eject()

View file

@ -307,11 +307,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
@property @property
def image_key_projection(self) -> fl.Linear: def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear return self.image_cross_attention.layer(("Distribute", 1, "Linear"), fl.Linear)
@property @property
def image_value_projection(self) -> fl.Linear: def image_value_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[2].Linear return self.image_cross_attention.layer(("Distribute", 2, "Linear"), fl.Linear)
@property @property
def scale(self) -> float: def scale(self) -> float:

View file

@ -9,17 +9,15 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.solvers import Solver from refiners.foundationals.latent_diffusion.solvers import Solver
T = TypeVar("T", bound="fl.Module")
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel") TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
class LatentDiffusionModel(fl.Module, ABC): class LatentDiffusionModel(fl.Module, ABC):
def __init__( def __init__(
self, self,
unet: fl.Module, unet: fl.Chain,
lda: LatentDiffusionAutoencoder, lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Module, clip_text_encoder: fl.Chain,
solver: Solver, solver: Solver,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = torch.float32, dtype: DType = torch.float32,

View file

@ -1,5 +1,3 @@
from typing import Iterable, cast
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
@ -93,28 +91,31 @@ class Controlnet(Passthrough):
# We run the condition encoder at each step. Caching the result # We run the condition encoder at each step. Caching the result
# is not worth it as subsequent runs take virtually no time (FG-374). # is not worth it as subsequent runs take virtually no time (FG-374).
self.DownBlocks[0].append(
self.layer(("DownBlocks", 0), Chain).append(
Residual( Residual(
UseContext("controlnet", f"condition_{name}"), UseContext("controlnet", f"condition_{name}"),
ConditionEncoder(device=device, dtype=dtype), ConditionEncoder(device=device, dtype=dtype),
), ),
) )
for residual_block in self.layers(ResidualBlock): for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain chain = residual_block.layer("Chain", Chain)
RangeAdapter2d( RangeAdapter2d(
target=chain.Conv2d_1, target=chain.layer("Conv2d_1", Conv2d),
channels=residual_block.out_channels, channels=residual_block.out_channels,
embedding_dim=1280, embedding_dim=1280,
context_key=f"timestep_embedding_{name}", context_key=f"timestep_embedding_{name}",
device=device, device=device,
dtype=dtype, dtype=dtype,
).inject(chain) ).inject(chain)
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)): for n, block in enumerate(self.layer("DownBlocks", DownBlocks)):
assert hasattr(block[0], "out_channels"), ( assert isinstance(block, Chain)
b0 = block[0]
assert hasattr(b0, "out_channels"), (
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`," "The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
f" {block[0]} does not." f" {b0} does not."
) )
out_channels: int = block[0].out_channels assert isinstance(out_channels := b0.out_channels, int)
block.append( block.append(
Passthrough( Passthrough(
Conv2d( Conv2d(
@ -123,7 +124,7 @@ class Controlnet(Passthrough):
Lambda(self._store_nth_residual(n)), Lambda(self._store_nth_residual(n)),
) )
) )
self.MiddleBlock.append( self.layer("MiddleBlock", MiddleBlock).append(
Passthrough( Passthrough(
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype), Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
Lambda(self._store_nth_residual(12)), Lambda(self._store_nth_residual(12)),
@ -166,6 +167,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
self.target.remove(self._controlnet[0]) self.target.remove(self._controlnet[0])
super().eject() super().eject()
@property
def controlnet(self) -> Controlnet:
return self._controlnet[0]
def init_context(self) -> Contexts: def init_context(self) -> Contexts:
return {"controlnet": {f"condition_{self.name}": None}} return {"controlnet": {f"condition_{self.name}": None}}

View file

@ -25,7 +25,7 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter": def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
for n, feat in zip(self.residual_indices, self._features, strict=True): for n, feat in zip(self.residual_indices, self._features, strict=True):
block = self.target.DownBlocks[n] block = self.target.layer(("DownBlocks", n), fl.Chain)
for t2i_layer in block.layers(layer_type=T2IFeatures): for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected" assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
block.insert_before_type(ResidualAccumulator, feat) block.insert_before_type(ResidualAccumulator, feat)
@ -33,5 +33,5 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
def eject(self: "SD1T2IAdapter") -> None: def eject(self: "SD1T2IAdapter") -> None:
for n, feat in zip(self.residual_indices, self._features, strict=True): for n, feat in zip(self.residual_indices, self._features, strict=True):
self.target.DownBlocks[n].remove(feat) self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
super().eject() super().eject()

View file

@ -6,6 +6,11 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts from refiners.fluxion.context import Contexts
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder
from refiners.foundationals.latent_diffusion.unet import (
ResidualAccumulator,
ResidualBlock,
ResidualConcatenator,
)
class TimestepEncoder(fl.Passthrough): class TimestepEncoder(fl.Passthrough):
@ -22,54 +27,6 @@ class TimestepEncoder(fl.Passthrough):
) )
class ResidualBlock(fl.Sum):
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
raise ValueError("Number of input and output channels must be divisible by num_groups.")
self.in_channels = in_channels
self.out_channels = out_channels
self.num_groups = num_groups
self.eps = eps
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
fl.Chain(
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
shortcut,
)
class CLIPLCrossAttention(CrossAttentionBlock2d): class CLIPLCrossAttention(CrossAttentionBlock2d):
def __init__( def __init__(
self, self,
@ -205,34 +162,6 @@ class MiddleBlock(fl.Chain):
) )
class ResidualAccumulator(fl.Passthrough):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Residual(
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
),
fl.SetContext(context="unet", key="residuals", callback=self.update),
)
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
residuals[self.n] = x
class ResidualConcatenator(fl.Chain):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Concatenate(
fl.Identity(),
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
dim=1,
),
)
class SD1UNet(fl.Chain): class SD1UNet(fl.Chain):
"""Stable Diffusion 1.5 U-Net. """Stable Diffusion 1.5 U-Net.
@ -275,9 +204,9 @@ class SD1UNet(fl.Chain):
), ),
) )
for residual_block in self.layers(ResidualBlock): for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain chain = residual_block.layer("Chain", fl.Chain)
RangeAdapter2d( RangeAdapter2d(
target=chain.Conv2d_1, target=chain.layer("Conv2d_1", fl.Conv2d),
channels=residual_block.out_channels, channels=residual_block.out_channels,
embedding_dim=1280, embedding_dim=1280,
context_key="timestep_embedding", context_key="timestep_embedding",

View file

@ -2,7 +2,7 @@ from torch import Tensor
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, SDXLUNet
from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures
@ -31,18 +31,19 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below) # Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
for n, feat in zip(self.residual_indices, self._features, strict=False): for n, feat in zip(self.residual_indices, self._features, strict=False):
block = self.target.DownBlocks[n] block = self.target.layer(("DownBlocks", n), fl.Chain)
sanity_check_t2i(block) sanity_check_t2i(block)
block.insert_before_type(ResidualAccumulator, feat) block.insert_before_type(ResidualAccumulator, feat)
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
sanity_check_t2i(self.target.MiddleBlock) mid_block = self.target.layer("MiddleBlock", MiddleBlock)
self.target.MiddleBlock.append(self._features[-1]) sanity_check_t2i(mid_block)
mid_block.append(self._features[-1])
return super().inject(parent) return super().inject(parent)
def eject(self: "SDXLT2IAdapter") -> None: def eject(self: "SDXLT2IAdapter") -> None:
# See `inject` re: `strict=False` # See `inject` re: `strict=False`
for n, feat in zip(self.residual_indices, self._features, strict=False): for n, feat in zip(self.residual_indices, self._features, strict=False):
self.target.DownBlocks[n].remove(feat) self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
self.target.MiddleBlock.remove(self._features[-1]) self.target.layer("MiddleBlock", MiddleBlock).remove(self._features[-1])
super().eject() super().eject()

View file

@ -72,7 +72,9 @@ class DoubleTextEncoder(fl.Chain):
fl.Parallel(text_encoder_l[:-2], text_encoder_g), fl.Parallel(text_encoder_l[:-2], text_encoder_g),
fl.Lambda(func=self.concatenate_embeddings), fl.Lambda(func=self.concatenate_embeddings),
) )
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel) TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
parent=self.layer("Parallel", fl.Parallel)
)
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]: def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
return super().__call__(text) return super().__call__(text)

View file

@ -10,7 +10,7 @@ from refiners.foundationals.latent_diffusion.range_adapter import (
RangeEncoder, RangeEncoder,
compute_sinusoidal_embedding, compute_sinusoidal_embedding,
) )
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( from refiners.foundationals.latent_diffusion.unet import (
ResidualAccumulator, ResidualAccumulator,
ResidualBlock, ResidualBlock,
ResidualConcatenator, ResidualConcatenator,
@ -267,9 +267,9 @@ class SDXLUNet(fl.Chain):
OutputBlock(device=device, dtype=dtype), OutputBlock(device=device, dtype=dtype),
) )
for residual_block in self.layers(ResidualBlock): for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain chain = residual_block.layer("Chain", fl.Chain)
RangeAdapter2d( RangeAdapter2d(
target=chain.Conv2d_1, target=chain.layer("Conv2d_1", fl.Conv2d),
channels=residual_block.out_channels, channels=residual_block.out_channels,
embedding_dim=1280, embedding_dim=1280,
context_key="timestep_embedding", context_key="timestep_embedding",

View file

@ -0,0 +1,79 @@
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
class ResidualBlock(fl.Sum):
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
raise ValueError("Number of input and output channels must be divisible by num_groups.")
self.in_channels = in_channels
self.out_channels = out_channels
self.num_groups = num_groups
self.eps = eps
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
fl.Chain(
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
shortcut,
)
class ResidualAccumulator(fl.Passthrough):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Residual(
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
),
fl.SetContext(context="unet", key="residuals", callback=self.update),
)
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
residuals[self.n] = x
class ResidualConcatenator(fl.Chain):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Concatenate(
fl.Identity(),
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
dim=1,
),
)

View file

@ -188,6 +188,7 @@ class MaskEncoder(fl.Chain):
def get_no_mask_dense_embedding( def get_no_mask_dense_embedding(
self, image_embedding_size: tuple[int, int], batch_size: int = 1 self, image_embedding_size: tuple[int, int], batch_size: int = 1
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]: ) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
return self.no_mask_embedding.reshape(1, -1, 1, 1).expand( no_mask_embedding = cast(Tensor, self.no_mask_embedding)
return no_mask_embedding.reshape(1, -1, 1, 1).expand(
batch_size, -1, image_embedding_size[0], image_embedding_size[1] batch_size, -1, image_embedding_size[0], image_embedding_size[1]
) )

View file

@ -22,8 +22,8 @@ def chain() -> Chain:
def test_weighted_module_adapter_insertion(chain: Chain): def test_weighted_module_adapter_insertion(chain: Chain):
parent = chain.Chain parent = chain.layer("Chain", Chain)
adaptee = parent.Linear adaptee = parent.layer("Linear", Linear)
adapter = DummyLinearAdapter(adaptee).inject(parent) adapter = DummyLinearAdapter(adaptee).inject(parent)
@ -39,7 +39,7 @@ def test_weighted_module_adapter_insertion(chain: Chain):
def test_chain_adapter_insertion(chain: Chain): def test_chain_adapter_insertion(chain: Chain):
parent = chain parent = chain
adaptee = parent.Chain adaptee = parent.layer("Chain", Chain)
adapter = DummyChainAdapter(adaptee) adapter = DummyChainAdapter(adaptee)
assert adaptee.parent == parent assert adaptee.parent == parent
@ -58,20 +58,20 @@ def test_chain_adapter_insertion(chain: Chain):
def test_weighted_module_adapter_structural_copy(chain: Chain): def test_weighted_module_adapter_structural_copy(chain: Chain):
parent = chain.Chain parent = chain.layer("Chain", Chain)
adaptee = parent.Linear adaptee = parent.layer("Linear", Linear)
DummyLinearAdapter(adaptee).inject(parent) DummyLinearAdapter(adaptee).inject(parent)
clone = chain.structural_copy() clone = chain.structural_copy()
cloned_adapter = clone.Chain.DummyLinearAdapter cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter)
assert cloned_adapter.parent == clone.Chain assert cloned_adapter.parent == clone.Chain
assert cloned_adapter.target == adaptee assert cloned_adapter.target == adaptee
def test_chain_adapter_structural_copy(chain: Chain): def test_chain_adapter_structural_copy(chain: Chain):
# Chain adapters cannot be copied by default. # Chain adapters cannot be copied by default.
adapter = DummyChainAdapter(chain.Chain).inject() adapter = DummyChainAdapter(chain.layer("Chain", Chain)).inject()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
chain.structural_copy() chain.structural_copy()

View file

@ -27,6 +27,7 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
assert conv_lora.scale == 1.0 assert conv_lora.scale == 1.0
assert conv_lora.in_channels == conv_lora.down.in_channels == 16 assert conv_lora.in_channels == conv_lora.down.in_channels == 16
assert conv_lora.out_channels == conv_lora.up.out_channels == 8 assert conv_lora.out_channels == conv_lora.up.out_channels == 8
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1) assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
# padding is set so the spatial dimensions are preserved # padding is set so the spatial dimensions are preserved
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1) assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@ -39,10 +40,12 @@ def test_scale_setter(lora: LinearLora) -> None:
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None: def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight) new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
x = torch.randn(1, 320) x = torch.randn(1, 320)
assert torch.allclose(lora(x), new_lora(x)) assert torch.allclose(lora(x), new_lora(x))
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight) new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
x = torch.randn(1, 16, 64, 64) x = torch.randn(1, 16, 64, 64)
assert torch.allclose(conv_lora(x), new_conv_lora(x)) assert torch.allclose(conv_lora(x), new_conv_lora(x))

View file

@ -15,8 +15,9 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-
dtype = torch.float64 dtype = torch.float64
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype)) chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
adaptee = chain.RangeEncoder.Linear_1 range_encoder = chain.layer("RangeEncoder", RangeEncoder)
adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder) adaptee = range_encoder.layer("Linear_1", Linear)
adapter = DummyLinearAdapter(adaptee).inject(range_encoder)
assert adapter.parent == chain.RangeEncoder assert adapter.parent == chain.RangeEncoder

View file

@ -30,8 +30,9 @@ def test_chain_getitem_accessor() -> None:
def test_chain_find_parent(): def test_chain_find_parent():
chain = fl.Chain(fl.Chain(fl.Linear(1, 1))) chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
subchain = chain.layer("Chain", fl.Chain)
assert chain.find_parent(chain.Chain.Linear) == chain.Chain assert chain.find_parent(subchain.layer("Linear", fl.Linear)) == subchain
assert chain.find_parent(fl.Linear(1, 1)) is None assert chain.find_parent(fl.Linear(1, 1)) is None
@ -64,17 +65,20 @@ def test_chain_walk() -> None:
fl.Chain(), fl.Chain(),
) )
assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)] sum_ = chain.layer("Sum", fl.Sum)
sum_chain = sum_.layer("Chain", fl.Chain)
assert list(chain.walk()) == [(sum_, chain), (chain.Chain, chain)]
assert list(chain.walk(fl.Linear)) == [ assert list(chain.walk(fl.Linear)) == [
(chain.Sum.Chain.Linear, chain.Sum.Chain), (sum_chain.Linear, sum_chain),
(chain.Sum.Linear, chain.Sum), (sum_.Linear, sum_),
] ]
assert list(chain.walk(recurse=True)) == [ assert list(chain.walk(recurse=True)) == [
(chain.Sum, chain), (sum_, chain),
(chain.Sum.Chain, chain.Sum), (sum_chain, sum_),
(chain.Sum.Chain.Linear, chain.Sum.Chain), (sum_chain.Linear, sum_chain),
(chain.Sum.Linear, chain.Sum), (sum_.Linear, sum_),
(chain.Chain, chain), (chain.Chain, chain),
] ]
@ -98,6 +102,29 @@ def test_chain_walk_stop_iteration() -> None:
assert len(list(chain.walk(predicate))) == 1 assert len(list(chain.walk(predicate))) == 1
def test_chain_layer() -> None:
chain = fl.Chain(
fl.Sum(fl.Chain(), fl.Chain()),
)
sum_ = chain.layer(0, fl.Sum)
assert chain.layer("Sum", fl.Sum) == sum_
assert chain.layer("Sum", fl.Chain) == sum_
chain_2 = chain.layer((0, 1), fl.Chain)
assert chain.layer((0, 1)) == chain_2
assert chain.layer((0, "Chain_2"), fl.Chain) == chain_2
assert chain.layer(("Sum", "Chain_2"), fl.Chain) == chain_2
assert chain.layer((), fl.Chain) == chain
with pytest.raises(AssertionError):
chain.layer((0, 1), fl.Sum)
with pytest.raises(AssertionError):
chain.layer((), fl.Sum)
def test_chain_layers() -> None: def test_chain_layers() -> None:
chain = fl.Chain( chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())), fl.Chain(fl.Chain(fl.Chain())),
@ -214,11 +241,12 @@ def test_chain_replace() -> None:
fl.Linear(1, 1), fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)), fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
) )
subchain = chain.layer("Chain", fl.Chain)
assert isinstance(chain.Chain[1], fl.Linear) assert isinstance(subchain[1], fl.Linear)
chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1)) subchain.replace(subchain[1], fl.Conv2d(1, 1, 1))
assert len(chain) == 3 assert len(chain) == 3
assert isinstance(chain.Chain[1], fl.Conv2d) assert isinstance(subchain[1], fl.Conv2d)
def test_chain_structural_copy() -> None: def test_chain_structural_copy() -> None:
@ -236,15 +264,18 @@ def test_chain_structural_copy() -> None:
m2 = m.structural_copy() m2 = m.structural_copy()
assert m.Linear == m2.Linear m_sum = m.layer("Sum", fl.Sum)
assert m.Sum.Linear_1 == m2.Sum.Linear_1 m2_sum = m2.layer("Sum", fl.Sum)
assert m.Sum.Linear_2 == m2.Sum.Linear_2
assert m.Sum != m2.Sum assert m.Linear == m2.Linear
assert m_sum.Linear_1 == m2_sum.Linear_1
assert m_sum.Linear_2 == m2_sum.Linear_2
assert m_sum != m2_sum
assert m != m2 assert m != m2
assert m.Sum.parent == m assert m_sum.parent == m
assert m2.Sum.parent == m2 assert m2_sum.parent == m2
y2 = m2(x) y2 = m2(x)
assert y2.shape == (7, 12) assert y2.shape == (7, 12)

View file

@ -10,9 +10,12 @@ def test_module_get_path() -> None:
fl.Sum(), fl.Sum(),
) )
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2" sum_1 = chain.layer("Sum_1", fl.Sum)
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2" linear_2 = sum_1.layer("Linear_2", fl.Linear)
assert chain.Sum_1.get_path() == "Chain.Sum_1"
assert linear_2.get_path(parent=sum_1) == "Chain.Sum_1.Linear_2"
assert linear_2.get_path(parent=sum_1, top=sum_1) == "Sum.Linear_2"
assert sum_1.get_path() == "Chain.Sum_1"
def test_module_basic_attributes() -> None: def test_module_basic_attributes() -> None:

View file

@ -85,13 +85,14 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.
def test_tokenizer_with_special_character(): def test_tokenizer_with_special_character():
clip_tokenizer = fl.Chain(CLIPTokenizer()) clip_tokenizer_chain = fl.Chain(CLIPTokenizer())
token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer) original_clip_tokenizer = clip_tokenizer_chain.layer("CLIPTokenizer", CLIPTokenizer)
new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42 token_extender = TokenExtender(original_clip_tokenizer)
new_token_id = max(original_clip_tokenizer.token_to_id_mapping.values()) + 42
token_extender.add_token("*", new_token_id) token_extender.add_token("*", new_token_id)
token_extender.inject(clip_tokenizer) token_extender.inject(clip_tokenizer_chain)
adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer) adapted_clip_tokenizer = clip_tokenizer_chain.ensure_find(CLIPTokenizer)
assert torch.allclose( assert torch.allclose(
adapted_clip_tokenizer.encode("*"), adapted_clip_tokenizer.encode("*"),

View file

@ -4,6 +4,7 @@ import pytest
import torch import torch
from refiners.fluxion import manual_seed from refiners.fluxion import manual_seed
from refiners.fluxion.layers import Chain
from refiners.fluxion.utils import no_grad from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
@ -33,7 +34,7 @@ def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None: def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
num_blocks = len(unet.UpBlocks) num_blocks = len(unet.layer("UpBlocks", Chain))
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1)) SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))

View file

@ -16,10 +16,11 @@ from tests.foundationals.segment_anything.utils import (
) )
from torch import Tensor from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion import manual_seed from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
@ -104,17 +105,22 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
manual_seed(seed=0) manual_seed(seed=0)
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device) x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) # type: ignore attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)
refiners_attention = FusedSelfAttention( refiners_attention = FusedSelfAttention(
embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
) )
refiners_attention.Linear_1.weight = attention.qkv.weight # type: ignore
refiners_attention.Linear_1.bias = attention.qkv.bias # type: ignore rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention)
refiners_attention.Linear_2.weight = attention.proj.weight # type: ignore linear_1 = refiners_attention.layer("Linear_1", fl.Linear)
refiners_attention.Linear_2.bias = attention.proj.bias # type: ignore linear_2 = refiners_attention.layer("Linear_2", fl.Linear)
refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h linear_1.weight = attention.qkv.weight
linear_1.bias = attention.qkv.bias
linear_2.weight = attention.proj.weight
linear_2.bias = attention.proj.bias
rpa.horizontal_embedding = attention.rel_pos_w
rpa.vertical_embedding = attention.rel_pos_h
y_1 = attention(x) y_1 = attention(x)
assert y_1.shape == x.shape assert y_1.shape == x.shape