make __getattr__ on Module return object, not Any

PyTorch chose to make it `Any` because they expect their users' code
to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321

This is not the case for us: in Refiners, having untyped code
goes against one of our core principles.

Note that there is currently an open PR in PyTorch to
return `Module | Tensor`, but in practice this is not always
correct either: https://github.com/pytorch/pytorch/pull/115074
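
To illustrate at a call site (a minimal sketch, not code from this commit): attribute access that resolves through `Module.__getattr__` is now typed `object`, so the type checker forces an explicit narrowing step instead of letting `Any` propagate:

```py
import refiners.fluxion.layers as fl

chain = fl.Chain(fl.Linear(1, 1))

# Dynamic attribute access resolves through `Module.__getattr__`.
linear = chain.Linear  # typed `object` after this change, no longer `Any`

# Using `linear` as a Linear now requires explicit narrowing:
assert isinstance(linear, fl.Linear)
print(linear.weight.shape)  # OK: the type checker knows `linear` is a Linear
```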

I also moved Residuals-related code from SD1 to latent_diffusion
because SDXL should not depend on SD1.
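
Accordingly, most of the diff below mechanically replaces dynamic attribute access (`unet.DownBlocks[0]`, `chain.Sum.Linear_1`, ...) with the typed `Chain.layer()` accessor this commit introduces. A before/after sketch based on the diff:

```py
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion import SD1UNet

unet = SD1UNet(in_channels=4)

# Before: `unet.DownBlocks` was `Any`, so indexing it type-checked silently.
# block = unet.DownBlocks[0]

# After: `unet.DownBlocks` is `object`; `layer()` asserts the type at runtime
# and returns the layer with a precise static type.
block = unet.layer(("DownBlocks", 0), fl.Chain)
```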
Pierre Chapuis 2024-02-05 17:10:05 +01:00
parent ec401133f1
commit 471ef91d1c
26 changed files with 288 additions and 168 deletions

View file

@@ -29,7 +29,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
)
unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
controlnet = unet.Controlnet
controlnet = adapter.controlnet
condition = torch.randn(1, 3, 512, 512)
adapter.set_controlnet_condition(condition=condition)

View file

@@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
target.set_timestep(timestep=timestep)
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
added_cond_kwargs = {}
if source_has_time_ids:
if isinstance(target, SDXLUNet):
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])

View file

@@ -11,7 +11,7 @@ from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -136,7 +136,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
converted_source.update(rel_items)

View file

@@ -43,14 +43,12 @@ class Adapter(Generic[T]):
), "Call the Chain constructor in the setup_adapter context."
self._target = [target]
if not isinstance(self.target, fl.ContextModule):
if isinstance(target, fl.ContextModule):
assert isinstance(target, fl.ContextModule)
with target.no_parent_refresh():
yield
else:
yield
return
_old_can_refresh_parent = target._can_refresh_parent
target._can_refresh_parent = False
yield
target._can_refresh_parent = _old_can_refresh_parent
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
"""Inject the adapter.

View file

@@ -3,7 +3,7 @@ import re
import sys
import traceback
from collections import defaultdict
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
import torch
from torch import Tensor, cat, device as Device, dtype as DType
@@ -362,6 +362,48 @@ class Chain(ContextModule):
recurse=recurse,
)
def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:
"""Access a layer of the Chain given its type.
Example:
```py
# same as my_chain["Linear_2"], asserts it is a Linear
my_chain.layer("Linear_2", fl.Linear)
# same as my_chain[3], asserts it is a Linear
my_chain.layer(3, fl.Linear)
# probably won't work
my_chain.layer("Conv2d", fl.Linear)
# same as my_chain["foo"][42]["bar"],
# assuming bar is a MyType and all parents are Chains
my_chain.layer(("foo", 42, "bar"), fl.MyType)
```
Args:
key: The key or path of the layer.
layer_type: The type of the layer.
Returns:
The layer.
Raises:
AssertionError: If the layer doesn't exist or the type is invalid.
"""
if isinstance(key, (str, int)):
r = self[key]
assert isinstance(r, layer_type), f"layer {key} is {type(r)}, not {layer_type}"
return r
if len(key) == 0:
assert isinstance(self, layer_type), f"layer is {type(self)}, not {layer_type}"
return self
if len(key) == 1:
return self.layer(key[0], layer_type)
return self.layer(key[0], Chain).layer(key[1:], layer_type)
def layers(
self,
layer_type: type[T],
@@ -576,6 +618,7 @@ class Chain(ContextModule):
require extra GPU memory since the weights are in the leaves and hence not copied.
"""
if hasattr(self, "_pre_structural_copy"):
assert callable(self._pre_structural_copy)
self._pre_structural_copy()
modules = [structural_copy(m) for m in self]
@@ -586,6 +629,7 @@ class Chain(ContextModule):
clone.append(module=module)
if hasattr(clone, "_post_structural_copy"):
assert callable(clone._post_structural_copy)
clone._post_structural_copy(self)
return clone

View file

@@ -1,11 +1,12 @@
import contextlib
import sys
from collections import defaultdict
from inspect import Parameter, signature
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Iterator, Sequence, TypedDict, TypeVar, cast
from torch import device as Device, dtype as DType
from torch import Tensor, device as Device, dtype as DType
from torch.nn.modules.module import Module as TorchModule
from refiners.fluxion.context import Context, ContextProvider
@@ -29,7 +30,13 @@ class Module(TorchModule):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType]
def __getattr__(self, name: str) -> Any:
def __getattr__(self, name: str) -> object:
# Note: PyTorch returns `Any` as of 2.2 and is considering
# going back to `Tensor | Module`, but the truth is it is
# impossible to type `__getattr__` correctly.
# Because PyTorch assumes its users write highly dynamic code,
# it returns Python's top type `Any`. In Refiners, static type
# checking is a core design value, hence we return `object` instead.
return super().__getattr__(name=name)
def __setattr__(self, name: str, value: Any) -> None:
@@ -220,10 +227,19 @@ class ContextModule(Module):
return super().get_path(parent=parent or self.parent, top=top)
@contextlib.contextmanager
def no_parent_refresh(self) -> Iterator[None]:
_old_can_refresh_parent = self._can_refresh_parent
self._can_refresh_parent = False
yield
self._can_refresh_parent = _old_can_refresh_parent
class WeightedModule(Module):
"""A module with a weight (Tensor) attribute."""
weight: Tensor
@property
def device(self) -> Device:
"""Return the device of the module's weight."""

View file

@@ -157,7 +157,7 @@ class Decoder(Chain):
)
resnet_layers[0].insert(1, attention_layer)
for _, layer in zip(range(3), resnet_layers[1:]):
channels: int = layer[-1].out_channels
channels: int = layer.layer(-1, Resnet).out_channels
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
super().__init__(
Conv2d(

View file

@@ -73,7 +73,7 @@ class FreeUResidualConcatenator(fl.Concatenate):
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
assert len(backbone_scales) == len(skip_scales)
assert len(backbone_scales) <= len(target.UpBlocks)
assert len(backbone_scales) <= len(target.layer("UpBlocks", fl.Chain))
self.backbone_scales = backbone_scales
self.skip_scales = skip_scales
with self.setup_adapter(target):
@@ -88,7 +88,7 @@ class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
def eject(self) -> None:
for n in range(len(self.backbone_scales)):
block = self.target.UpBlocks[n]
block = self.target.layer(("UpBlocks", n), fl.Chain)
concat = block.ensure_find(FreeUResidualConcatenator)
block.replace(concat, ResidualConcatenator(-n - 2))
super().eject()

View file

@@ -307,11 +307,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
@property
def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear
return self.image_cross_attention.layer(("Distribute", 1, "Linear"), fl.Linear)
@property
def image_value_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[2].Linear
return self.image_cross_attention.layer(("Distribute", 2, "Linear"), fl.Linear)
@property
def scale(self) -> float:

View file

@@ -9,17 +9,15 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.solvers import Solver
T = TypeVar("T", bound="fl.Module")
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
class LatentDiffusionModel(fl.Module, ABC):
def __init__(
self,
unet: fl.Module,
unet: fl.Chain,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Module,
clip_text_encoder: fl.Chain,
solver: Solver,
device: Device | str = "cpu",
dtype: DType = torch.float32,

View file

@@ -1,5 +1,3 @@
from typing import Iterable, cast
from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.adapters.adapter import Adapter
@@ -93,28 +91,31 @@ class Controlnet(Passthrough):
# We run the condition encoder at each step. Caching the result
# is not worth it as subsequent runs take virtually no time (FG-374).
self.DownBlocks[0].append(
self.layer(("DownBlocks", 0), Chain).append(
Residual(
UseContext("controlnet", f"condition_{name}"),
ConditionEncoder(device=device, dtype=dtype),
),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
chain = residual_block.layer("Chain", Chain)
RangeAdapter2d(
target=chain.Conv2d_1,
target=chain.layer("Conv2d_1", Conv2d),
channels=residual_block.out_channels,
embedding_dim=1280,
context_key=f"timestep_embedding_{name}",
device=device,
dtype=dtype,
).inject(chain)
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
assert hasattr(block[0], "out_channels"), (
for n, block in enumerate(self.layer("DownBlocks", DownBlocks)):
assert isinstance(block, Chain)
b0 = block[0]
assert hasattr(b0, "out_channels"), (
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
f" {block[0]} does not."
f" {b0} does not."
)
out_channels: int = block[0].out_channels
assert isinstance(out_channels := b0.out_channels, int)
block.append(
Passthrough(
Conv2d(
@@ -123,7 +124,7 @@ class Controlnet(Passthrough):
Lambda(self._store_nth_residual(n)),
)
)
self.MiddleBlock.append(
self.layer("MiddleBlock", MiddleBlock).append(
Passthrough(
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
Lambda(self._store_nth_residual(12)),
@@ -166,6 +167,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
self.target.remove(self._controlnet[0])
super().eject()
@property
def controlnet(self) -> Controlnet:
return self._controlnet[0]
def init_context(self) -> Contexts:
return {"controlnet": {f"condition_{self.name}": None}}

View file

@@ -25,7 +25,7 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
for n, feat in zip(self.residual_indices, self._features, strict=True):
block = self.target.DownBlocks[n]
block = self.target.layer(("DownBlocks", n), fl.Chain)
for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
block.insert_before_type(ResidualAccumulator, feat)
@@ -33,5 +33,5 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
def eject(self: "SD1T2IAdapter") -> None:
for n, feat in zip(self.residual_indices, self._features, strict=True):
self.target.DownBlocks[n].remove(feat)
self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
super().eject()

View file

@@ -6,6 +6,11 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder
from refiners.foundationals.latent_diffusion.unet import (
ResidualAccumulator,
ResidualBlock,
ResidualConcatenator,
)
class TimestepEncoder(fl.Passthrough):
@@ -22,54 +27,6 @@ class TimestepEncoder(fl.Passthrough):
)
class ResidualBlock(fl.Sum):
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
raise ValueError("Number of input and output channels must be divisible by num_groups.")
self.in_channels = in_channels
self.out_channels = out_channels
self.num_groups = num_groups
self.eps = eps
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
fl.Chain(
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
shortcut,
)
class CLIPLCrossAttention(CrossAttentionBlock2d):
def __init__(
self,
@@ -205,34 +162,6 @@ class MiddleBlock(fl.Chain):
)
class ResidualAccumulator(fl.Passthrough):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Residual(
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
),
fl.SetContext(context="unet", key="residuals", callback=self.update),
)
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
residuals[self.n] = x
class ResidualConcatenator(fl.Chain):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Concatenate(
fl.Identity(),
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
dim=1,
),
)
class SD1UNet(fl.Chain):
"""Stable Diffusion 1.5 U-Net.
@@ -275,9 +204,9 @@ class SD1UNet(fl.Chain):
),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
chain = residual_block.layer("Chain", fl.Chain)
RangeAdapter2d(
target=chain.Conv2d_1,
target=chain.layer("Conv2d_1", fl.Conv2d),
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",

View file

@@ -2,7 +2,7 @@ from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, SDXLUNet
from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures
@@ -31,18 +31,19 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
for n, feat in zip(self.residual_indices, self._features, strict=False):
block = self.target.DownBlocks[n]
block = self.target.layer(("DownBlocks", n), fl.Chain)
sanity_check_t2i(block)
block.insert_before_type(ResidualAccumulator, feat)
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
sanity_check_t2i(self.target.MiddleBlock)
self.target.MiddleBlock.append(self._features[-1])
mid_block = self.target.layer("MiddleBlock", MiddleBlock)
sanity_check_t2i(mid_block)
mid_block.append(self._features[-1])
return super().inject(parent)
def eject(self: "SDXLT2IAdapter") -> None:
# See `inject` re: `strict=False`
for n, feat in zip(self.residual_indices, self._features, strict=False):
self.target.DownBlocks[n].remove(feat)
self.target.MiddleBlock.remove(self._features[-1])
self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
self.target.layer("MiddleBlock", MiddleBlock).remove(self._features[-1])
super().eject()

View file

@@ -72,7 +72,9 @@ class DoubleTextEncoder(fl.Chain):
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
fl.Lambda(func=self.concatenate_embeddings),
)
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
parent=self.layer("Parallel", fl.Parallel)
)
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
return super().__call__(text)

View file

@@ -10,7 +10,7 @@ from refiners.foundationals.latent_diffusion.range_adapter import (
RangeEncoder,
compute_sinusoidal_embedding,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
from refiners.foundationals.latent_diffusion.unet import (
ResidualAccumulator,
ResidualBlock,
ResidualConcatenator,
@@ -267,9 +267,9 @@ class SDXLUNet(fl.Chain):
OutputBlock(device=device, dtype=dtype),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
chain = residual_block.layer("Chain", fl.Chain)
RangeAdapter2d(
target=chain.Conv2d_1,
target=chain.layer("Conv2d_1", fl.Conv2d),
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",

View file

@@ -0,0 +1,79 @@
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
class ResidualBlock(fl.Sum):
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
raise ValueError("Number of input and output channels must be divisible by num_groups.")
self.in_channels = in_channels
self.out_channels = out_channels
self.num_groups = num_groups
self.eps = eps
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
fl.Chain(
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
shortcut,
)
class ResidualAccumulator(fl.Passthrough):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Residual(
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
),
fl.SetContext(context="unet", key="residuals", callback=self.update),
)
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
residuals[self.n] = x
class ResidualConcatenator(fl.Chain):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Concatenate(
fl.Identity(),
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
dim=1,
),
)

View file

@@ -188,6 +188,7 @@ class MaskEncoder(fl.Chain):
def get_no_mask_dense_embedding(
self, image_embedding_size: tuple[int, int], batch_size: int = 1
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
no_mask_embedding = cast(Tensor, self.no_mask_embedding)
return no_mask_embedding.reshape(1, -1, 1, 1).expand(
batch_size, -1, image_embedding_size[0], image_embedding_size[1]
)

View file

@@ -22,8 +22,8 @@ def chain() -> Chain:
def test_weighted_module_adapter_insertion(chain: Chain):
parent = chain.Chain
adaptee = parent.Linear
parent = chain.layer("Chain", Chain)
adaptee = parent.layer("Linear", Linear)
adapter = DummyLinearAdapter(adaptee).inject(parent)
@@ -39,7 +39,7 @@ def test_weighted_module_adapter_insertion(chain: Chain):
def test_chain_adapter_insertion(chain: Chain):
parent = chain
adaptee = parent.Chain
adaptee = parent.layer("Chain", Chain)
adapter = DummyChainAdapter(adaptee)
assert adaptee.parent == parent
@@ -58,20 +58,20 @@ def test_chain_adapter_insertion(chain: Chain):
def test_weighted_module_adapter_structural_copy(chain: Chain):
parent = chain.Chain
adaptee = parent.Linear
parent = chain.layer("Chain", Chain)
adaptee = parent.layer("Linear", Linear)
DummyLinearAdapter(adaptee).inject(parent)
clone = chain.structural_copy()
cloned_adapter = clone.Chain.DummyLinearAdapter
cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter)
assert cloned_adapter.parent == clone.Chain
assert cloned_adapter.target == adaptee
def test_chain_adapter_structural_copy(chain: Chain):
# Chain adapters cannot be copied by default.
adapter = DummyChainAdapter(chain.Chain).inject()
adapter = DummyChainAdapter(chain.layer("Chain", Chain)).inject()
with pytest.raises(RuntimeError):
chain.structural_copy()

View file

@@ -27,6 +27,7 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
assert conv_lora.scale == 1.0
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
# padding is set so the spatial dimensions are preserved
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
@@ -39,10 +40,12 @@ def test_scale_setter(lora: LinearLora) -> None:
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
x = torch.randn(1, 320)
assert torch.allclose(lora(x), new_lora(x))
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
x = torch.randn(1, 16, 64, 64)
assert torch.allclose(conv_lora(x), new_conv_lora(x))

View file

@@ -15,8 +15,9 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-
dtype = torch.float64
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
adaptee = chain.RangeEncoder.Linear_1
adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder)
range_encoder = chain.layer("RangeEncoder", RangeEncoder)
adaptee = range_encoder.layer("Linear_1", Linear)
adapter = DummyLinearAdapter(adaptee).inject(range_encoder)
assert adapter.parent == chain.RangeEncoder

View file

@@ -30,8 +30,9 @@ def test_chain_getitem_accessor() -> None:
def test_chain_find_parent():
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
subchain = chain.layer("Chain", fl.Chain)
assert chain.find_parent(chain.Chain.Linear) == chain.Chain
assert chain.find_parent(subchain.layer("Linear", fl.Linear)) == subchain
assert chain.find_parent(fl.Linear(1, 1)) is None
@@ -64,17 +65,20 @@ def test_chain_walk() -> None:
fl.Chain(),
)
assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)]
sum_ = chain.layer("Sum", fl.Sum)
sum_chain = sum_.layer("Chain", fl.Chain)
assert list(chain.walk()) == [(sum_, chain), (chain.Chain, chain)]
assert list(chain.walk(fl.Linear)) == [
(chain.Sum.Chain.Linear, chain.Sum.Chain),
(chain.Sum.Linear, chain.Sum),
(sum_chain.Linear, sum_chain),
(sum_.Linear, sum_),
]
assert list(chain.walk(recurse=True)) == [
(chain.Sum, chain),
(chain.Sum.Chain, chain.Sum),
(chain.Sum.Chain.Linear, chain.Sum.Chain),
(chain.Sum.Linear, chain.Sum),
(sum_, chain),
(sum_chain, sum_),
(sum_chain.Linear, sum_chain),
(sum_.Linear, sum_),
(chain.Chain, chain),
]
@@ -98,6 +102,29 @@ def test_chain_walk_stop_iteration() -> None:
assert len(list(chain.walk(predicate))) == 1
def test_chain_layer() -> None:
chain = fl.Chain(
fl.Sum(fl.Chain(), fl.Chain()),
)
sum_ = chain.layer(0, fl.Sum)
assert chain.layer("Sum", fl.Sum) == sum_
assert chain.layer("Sum", fl.Chain) == sum_
chain_2 = chain.layer((0, 1), fl.Chain)
assert chain.layer((0, 1)) == chain_2
assert chain.layer((0, "Chain_2"), fl.Chain) == chain_2
assert chain.layer(("Sum", "Chain_2"), fl.Chain) == chain_2
assert chain.layer((), fl.Chain) == chain
with pytest.raises(AssertionError):
chain.layer((0, 1), fl.Sum)
with pytest.raises(AssertionError):
chain.layer((), fl.Sum)
def test_chain_layers() -> None:
chain = fl.Chain(
fl.Chain(fl.Chain(fl.Chain())),
@@ -214,11 +241,12 @@ def test_chain_replace() -> None:
fl.Linear(1, 1),
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
)
subchain = chain.layer("Chain", fl.Chain)
assert isinstance(chain.Chain[1], fl.Linear)
chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1))
assert isinstance(subchain[1], fl.Linear)
subchain.replace(subchain[1], fl.Conv2d(1, 1, 1))
assert len(chain) == 3
assert isinstance(chain.Chain[1], fl.Conv2d)
assert isinstance(subchain[1], fl.Conv2d)
def test_chain_structural_copy() -> None:
@@ -236,15 +264,18 @@ def test_chain_structural_copy() -> None:
m2 = m.structural_copy()
assert m.Linear == m2.Linear
assert m.Sum.Linear_1 == m2.Sum.Linear_1
assert m.Sum.Linear_2 == m2.Sum.Linear_2
m_sum = m.layer("Sum", fl.Sum)
m2_sum = m2.layer("Sum", fl.Sum)
assert m.Sum != m2.Sum
assert m.Linear == m2.Linear
assert m_sum.Linear_1 == m2_sum.Linear_1
assert m_sum.Linear_2 == m2_sum.Linear_2
assert m_sum != m2_sum
assert m != m2
assert m.Sum.parent == m
assert m2.Sum.parent == m2
assert m_sum.parent == m
assert m2_sum.parent == m2
y2 = m2(x)
assert y2.shape == (7, 12)

View file

@@ -10,9 +10,12 @@ def test_module_get_path() -> None:
fl.Sum(),
)
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
assert chain.Sum_1.get_path() == "Chain.Sum_1"
sum_1 = chain.layer("Sum_1", fl.Sum)
linear_2 = sum_1.layer("Linear_2", fl.Linear)
assert linear_2.get_path(parent=sum_1) == "Chain.Sum_1.Linear_2"
assert linear_2.get_path(parent=sum_1, top=sum_1) == "Sum.Linear_2"
assert sum_1.get_path() == "Chain.Sum_1"
def test_module_basic_attributes() -> None:

View file

@@ -85,13 +85,14 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.
def test_tokenizer_with_special_character():
clip_tokenizer = fl.Chain(CLIPTokenizer())
token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer)
new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42
clip_tokenizer_chain = fl.Chain(CLIPTokenizer())
original_clip_tokenizer = clip_tokenizer_chain.layer("CLIPTokenizer", CLIPTokenizer)
token_extender = TokenExtender(original_clip_tokenizer)
new_token_id = max(original_clip_tokenizer.token_to_id_mapping.values()) + 42
token_extender.add_token("*", new_token_id)
token_extender.inject(clip_tokenizer)
token_extender.inject(clip_tokenizer_chain)
adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)
adapted_clip_tokenizer = clip_tokenizer_chain.ensure_find(CLIPTokenizer)
assert torch.allclose(
adapted_clip_tokenizer.encode("*"),

View file

@@ -4,6 +4,7 @@ import pytest
import torch
from refiners.fluxion import manual_seed
from refiners.fluxion.layers import Chain
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
@@ -33,7 +34,7 @@ def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
num_blocks = len(unet.UpBlocks)
num_blocks = len(unet.layer("UpBlocks", Chain))
with pytest.raises(AssertionError):
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))

View file

@@ -16,10 +16,11 @@ from tests.foundationals.segment_anything.utils import (
)
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
@@ -104,17 +105,22 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
manual_seed(seed=0)
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) # type: ignore
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)
refiners_attention = FusedSelfAttention(
embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
)
refiners_attention.Linear_1.weight = attention.qkv.weight # type: ignore
refiners_attention.Linear_1.bias = attention.qkv.bias # type: ignore
refiners_attention.Linear_2.weight = attention.proj.weight # type: ignore
refiners_attention.Linear_2.bias = attention.proj.bias # type: ignore
refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h
rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention)
linear_1 = refiners_attention.layer("Linear_1", fl.Linear)
linear_2 = refiners_attention.layer("Linear_2", fl.Linear)
linear_1.weight = attention.qkv.weight
linear_1.bias = attention.qkv.bias
linear_2.weight = attention.proj.weight
linear_2.bias = attention.proj.bias
rpa.horizontal_embedding = attention.rel_pos_w
rpa.vertical_embedding = attention.rel_pos_h
y_1 = attention(x)
assert y_1.shape == x.shape