mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
make __getattr__
on Module return object, not Any
PyTorch chose to make it Any because they expect its users' code to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321 It is not the case for us, in Refiners having untyped code goes contrary to one of our core principles. Note that there is currently an open PR in PyTorch to return `Module | Tensor`, but in practice this is not always correct either: https://github.com/pytorch/pytorch/pull/115074 I also moved Residuals-related code from SD1 to latent_diffusion because SDXL should not depend on SD1.
This commit is contained in:
parent
ec401133f1
commit
471ef91d1c
|
@ -29,7 +29,7 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
|
||||||
)
|
)
|
||||||
unet = SD1UNet(in_channels=4)
|
unet = SD1UNet(in_channels=4)
|
||||||
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
|
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
|
||||||
controlnet = unet.Controlnet
|
controlnet = adapter.controlnet
|
||||||
|
|
||||||
condition = torch.randn(1, 3, 512, 512)
|
condition = torch.randn(1, 3, 512, 512)
|
||||||
adapter.set_controlnet_condition(condition=condition)
|
adapter.set_controlnet_condition(condition=condition)
|
||||||
|
|
|
@ -39,7 +39,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
target.set_timestep(timestep=timestep)
|
target.set_timestep(timestep=timestep)
|
||||||
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
target.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
||||||
added_cond_kwargs = {}
|
added_cond_kwargs = {}
|
||||||
if source_has_time_ids:
|
if isinstance(target, SDXLUNet):
|
||||||
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
|
||||||
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
target.set_time_ids(time_ids=added_cond_kwargs["time_ids"])
|
||||||
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||||
|
|
|
@ -11,7 +11,7 @@ from torch import Tensor
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
|
from refiners.fluxion.utils import load_tensors, manual_seed, save_to_safetensors
|
||||||
from refiners.foundationals.segment_anything.image_encoder import SAMViTH
|
from refiners.foundationals.segment_anything.image_encoder import PositionalEncoder, SAMViTH
|
||||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||||
|
|
||||||
|
@ -136,7 +136,8 @@ def convert_vit(vit: nn.Module) -> dict[str, Tensor]:
|
||||||
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
embed = pos_embed.reshape_as(refiners_sam_vit_h.PositionalEncoder.Parameter.weight)
|
positional_encoder = refiners_sam_vit_h.layer("PositionalEncoder", PositionalEncoder)
|
||||||
|
embed = pos_embed.reshape_as(positional_encoder.layer("Parameter", fl.Parameter).weight)
|
||||||
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
|
converted_source["PositionalEncoder.Parameter.weight"] = embed # type: ignore
|
||||||
converted_source.update(rel_items)
|
converted_source.update(rel_items)
|
||||||
|
|
||||||
|
|
|
@ -43,14 +43,12 @@ class Adapter(Generic[T]):
|
||||||
), "Call the Chain constructor in the setup_adapter context."
|
), "Call the Chain constructor in the setup_adapter context."
|
||||||
self._target = [target]
|
self._target = [target]
|
||||||
|
|
||||||
if not isinstance(self.target, fl.ContextModule):
|
if isinstance(target, fl.ContextModule):
|
||||||
|
assert isinstance(target, fl.ContextModule)
|
||||||
|
with target.no_parent_refresh():
|
||||||
|
yield
|
||||||
|
else:
|
||||||
yield
|
yield
|
||||||
return
|
|
||||||
|
|
||||||
_old_can_refresh_parent = target._can_refresh_parent
|
|
||||||
target._can_refresh_parent = False
|
|
||||||
yield
|
|
||||||
target._can_refresh_parent = _old_can_refresh_parent
|
|
||||||
|
|
||||||
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
||||||
"""Inject the adapter.
|
"""Inject the adapter.
|
||||||
|
|
|
@ -3,7 +3,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, cat, device as Device, dtype as DType
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
|
@ -362,6 +362,48 @@ class Chain(ContextModule):
|
||||||
recurse=recurse,
|
recurse=recurse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def layer(self, key: str | int | Sequence[str | int], layer_type: type[T] = Module) -> T:
|
||||||
|
"""Access a layer of the Chain given its type.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```py
|
||||||
|
# same as my_chain["Linear_2"], asserts it is a Linear
|
||||||
|
my_chain.layer("Linear_2", fl.Linear)
|
||||||
|
|
||||||
|
|
||||||
|
# same as my_chain[3], asserts it is a Linear
|
||||||
|
my_chain.layer(3, fl.Linear)
|
||||||
|
|
||||||
|
# probably won't work
|
||||||
|
my_chain.layer("Conv2d", fl.Linear)
|
||||||
|
|
||||||
|
|
||||||
|
# same as my_chain["foo"][42]["bar"],
|
||||||
|
# assuming bar is a MyType and all parents are Chains
|
||||||
|
my_chain.layer(("foo", 42, "bar"), fl.MyType)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key or path of the layer.
|
||||||
|
layer_type: The type of the layer.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The layer.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the layer doesn't exist or the type is invalid.
|
||||||
|
"""
|
||||||
|
if isinstance(key, (str, int)):
|
||||||
|
r = self[key]
|
||||||
|
assert isinstance(r, layer_type), f"layer {key} is {type(r)}, not {layer_type}"
|
||||||
|
return r
|
||||||
|
if len(key) == 0:
|
||||||
|
assert isinstance(self, layer_type), f"layer is {type(self)}, not {layer_type}"
|
||||||
|
return self
|
||||||
|
if len(key) == 1:
|
||||||
|
return self.layer(key[0], layer_type)
|
||||||
|
return self.layer(key[0], Chain).layer(key[1:], layer_type)
|
||||||
|
|
||||||
def layers(
|
def layers(
|
||||||
self,
|
self,
|
||||||
layer_type: type[T],
|
layer_type: type[T],
|
||||||
|
@ -576,6 +618,7 @@ class Chain(ContextModule):
|
||||||
require extra GPU memory since the weights are in the leaves and hence not copied.
|
require extra GPU memory since the weights are in the leaves and hence not copied.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, "_pre_structural_copy"):
|
if hasattr(self, "_pre_structural_copy"):
|
||||||
|
assert callable(self._pre_structural_copy)
|
||||||
self._pre_structural_copy()
|
self._pre_structural_copy()
|
||||||
|
|
||||||
modules = [structural_copy(m) for m in self]
|
modules = [structural_copy(m) for m in self]
|
||||||
|
@ -586,6 +629,7 @@ class Chain(ContextModule):
|
||||||
clone.append(module=module)
|
clone.append(module=module)
|
||||||
|
|
||||||
if hasattr(clone, "_post_structural_copy"):
|
if hasattr(clone, "_post_structural_copy"):
|
||||||
|
assert callable(clone._post_structural_copy)
|
||||||
clone._post_structural_copy(self)
|
clone._post_structural_copy(self)
|
||||||
|
|
||||||
return clone
|
return clone
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
|
import contextlib
|
||||||
import sys
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from inspect import Parameter, signature
|
from inspect import Parameter, signature
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Iterator, Sequence, TypedDict, TypeVar, cast
|
||||||
|
|
||||||
from torch import device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from torch.nn.modules.module import Module as TorchModule
|
from torch.nn.modules.module import Module as TorchModule
|
||||||
|
|
||||||
from refiners.fluxion.context import Context, ContextProvider
|
from refiners.fluxion.context import Context, ContextProvider
|
||||||
|
@ -29,7 +30,13 @@ class Module(TorchModule):
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().__init__(*args, *kwargs) # type: ignore[reportUnknownMemberType]
|
super().__init__(*args, *kwargs) # type: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> object:
|
||||||
|
# Note: PyTorch returns `Any` as of 2.2 and is considering
|
||||||
|
# going back to `Tensor | Module`, but the truth is it is
|
||||||
|
# impossible to type `__getattr__` correctly.
|
||||||
|
# Because PyTorch assumes its users write highly dynamic code,
|
||||||
|
# it returns Python's top type `Any`. In Refiners, static type
|
||||||
|
# checking is a core design value, hence we return `object` instead.
|
||||||
return super().__getattr__(name=name)
|
return super().__getattr__(name=name)
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None:
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
@ -220,10 +227,19 @@ class ContextModule(Module):
|
||||||
|
|
||||||
return super().get_path(parent=parent or self.parent, top=top)
|
return super().get_path(parent=parent or self.parent, top=top)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def no_parent_refresh(self) -> Iterator[None]:
|
||||||
|
_old_can_refresh_parent = self._can_refresh_parent
|
||||||
|
self._can_refresh_parent = False
|
||||||
|
yield
|
||||||
|
self._can_refresh_parent = _old_can_refresh_parent
|
||||||
|
|
||||||
|
|
||||||
class WeightedModule(Module):
|
class WeightedModule(Module):
|
||||||
"""A module with a weight (Tensor) attribute."""
|
"""A module with a weight (Tensor) attribute."""
|
||||||
|
|
||||||
|
weight: Tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> Device:
|
def device(self) -> Device:
|
||||||
"""Return the device of the module's weight."""
|
"""Return the device of the module's weight."""
|
||||||
|
|
|
@ -157,7 +157,7 @@ class Decoder(Chain):
|
||||||
)
|
)
|
||||||
resnet_layers[0].insert(1, attention_layer)
|
resnet_layers[0].insert(1, attention_layer)
|
||||||
for _, layer in zip(range(3), resnet_layers[1:]):
|
for _, layer in zip(range(3), resnet_layers[1:]):
|
||||||
channels: int = layer[-1].out_channels
|
channels: int = layer.layer(-1, Resnet).out_channels
|
||||||
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
|
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Conv2d(
|
Conv2d(
|
||||||
|
|
|
@ -73,7 +73,7 @@ class FreeUResidualConcatenator(fl.Concatenate):
|
||||||
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
|
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
|
def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
|
||||||
assert len(backbone_scales) == len(skip_scales)
|
assert len(backbone_scales) == len(skip_scales)
|
||||||
assert len(backbone_scales) <= len(target.UpBlocks)
|
assert len(backbone_scales) <= len(target.layer("UpBlocks", fl.Chain))
|
||||||
self.backbone_scales = backbone_scales
|
self.backbone_scales = backbone_scales
|
||||||
self.skip_scales = skip_scales
|
self.skip_scales = skip_scales
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
|
@ -88,7 +88,7 @@ class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
|
||||||
def eject(self) -> None:
|
def eject(self) -> None:
|
||||||
for n in range(len(self.backbone_scales)):
|
for n in range(len(self.backbone_scales)):
|
||||||
block = self.target.UpBlocks[n]
|
block = self.target.layer(("UpBlocks", n), fl.Chain)
|
||||||
concat = block.ensure_find(FreeUResidualConcatenator)
|
concat = block.ensure_find(FreeUResidualConcatenator)
|
||||||
block.replace(concat, ResidualConcatenator(-n - 2))
|
block.replace(concat, ResidualConcatenator(-n - 2))
|
||||||
super().eject()
|
super().eject()
|
||||||
|
|
|
@ -307,11 +307,11 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_key_projection(self) -> fl.Linear:
|
def image_key_projection(self) -> fl.Linear:
|
||||||
return self.image_cross_attention.Distribute[1].Linear
|
return self.image_cross_attention.layer(("Distribute", 1, "Linear"), fl.Linear)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_value_projection(self) -> fl.Linear:
|
def image_value_projection(self) -> fl.Linear:
|
||||||
return self.image_cross_attention.Distribute[2].Linear
|
return self.image_cross_attention.layer(("Distribute", 2, "Linear"), fl.Linear)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
|
|
|
@ -9,17 +9,15 @@ import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
from refiners.foundationals.latent_diffusion.solvers import Solver
|
from refiners.foundationals.latent_diffusion.solvers import Solver
|
||||||
|
|
||||||
T = TypeVar("T", bound="fl.Module")
|
|
||||||
|
|
||||||
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
||||||
|
|
||||||
|
|
||||||
class LatentDiffusionModel(fl.Module, ABC):
|
class LatentDiffusionModel(fl.Module, ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
unet: fl.Module,
|
unet: fl.Chain,
|
||||||
lda: LatentDiffusionAutoencoder,
|
lda: LatentDiffusionAutoencoder,
|
||||||
clip_text_encoder: fl.Module,
|
clip_text_encoder: fl.Chain,
|
||||||
solver: Solver,
|
solver: Solver,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
from typing import Iterable, cast
|
|
||||||
|
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
@ -93,28 +91,31 @@ class Controlnet(Passthrough):
|
||||||
|
|
||||||
# We run the condition encoder at each step. Caching the result
|
# We run the condition encoder at each step. Caching the result
|
||||||
# is not worth it as subsequent runs take virtually no time (FG-374).
|
# is not worth it as subsequent runs take virtually no time (FG-374).
|
||||||
self.DownBlocks[0].append(
|
|
||||||
|
self.layer(("DownBlocks", 0), Chain).append(
|
||||||
Residual(
|
Residual(
|
||||||
UseContext("controlnet", f"condition_{name}"),
|
UseContext("controlnet", f"condition_{name}"),
|
||||||
ConditionEncoder(device=device, dtype=dtype),
|
ConditionEncoder(device=device, dtype=dtype),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for residual_block in self.layers(ResidualBlock):
|
for residual_block in self.layers(ResidualBlock):
|
||||||
chain = residual_block.Chain
|
chain = residual_block.layer("Chain", Chain)
|
||||||
RangeAdapter2d(
|
RangeAdapter2d(
|
||||||
target=chain.Conv2d_1,
|
target=chain.layer("Conv2d_1", Conv2d),
|
||||||
channels=residual_block.out_channels,
|
channels=residual_block.out_channels,
|
||||||
embedding_dim=1280,
|
embedding_dim=1280,
|
||||||
context_key=f"timestep_embedding_{name}",
|
context_key=f"timestep_embedding_{name}",
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
).inject(chain)
|
).inject(chain)
|
||||||
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
|
for n, block in enumerate(self.layer("DownBlocks", DownBlocks)):
|
||||||
assert hasattr(block[0], "out_channels"), (
|
assert isinstance(block, Chain)
|
||||||
|
b0 = block[0]
|
||||||
|
assert hasattr(b0, "out_channels"), (
|
||||||
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
|
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
|
||||||
f" {block[0]} does not."
|
f" {b0} does not."
|
||||||
)
|
)
|
||||||
out_channels: int = block[0].out_channels
|
assert isinstance(out_channels := b0.out_channels, int)
|
||||||
block.append(
|
block.append(
|
||||||
Passthrough(
|
Passthrough(
|
||||||
Conv2d(
|
Conv2d(
|
||||||
|
@ -123,7 +124,7 @@ class Controlnet(Passthrough):
|
||||||
Lambda(self._store_nth_residual(n)),
|
Lambda(self._store_nth_residual(n)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.MiddleBlock.append(
|
self.layer("MiddleBlock", MiddleBlock).append(
|
||||||
Passthrough(
|
Passthrough(
|
||||||
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
|
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
|
||||||
Lambda(self._store_nth_residual(12)),
|
Lambda(self._store_nth_residual(12)),
|
||||||
|
@ -166,6 +167,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
||||||
self.target.remove(self._controlnet[0])
|
self.target.remove(self._controlnet[0])
|
||||||
super().eject()
|
super().eject()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def controlnet(self) -> Controlnet:
|
||||||
|
return self._controlnet[0]
|
||||||
|
|
||||||
def init_context(self) -> Contexts:
|
def init_context(self) -> Contexts:
|
||||||
return {"controlnet": {f"condition_{self.name}": None}}
|
return {"controlnet": {f"condition_{self.name}": None}}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
|
||||||
|
|
||||||
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
|
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
|
||||||
for n, feat in zip(self.residual_indices, self._features, strict=True):
|
for n, feat in zip(self.residual_indices, self._features, strict=True):
|
||||||
block = self.target.DownBlocks[n]
|
block = self.target.layer(("DownBlocks", n), fl.Chain)
|
||||||
for t2i_layer in block.layers(layer_type=T2IFeatures):
|
for t2i_layer in block.layers(layer_type=T2IFeatures):
|
||||||
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
|
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
|
||||||
block.insert_before_type(ResidualAccumulator, feat)
|
block.insert_before_type(ResidualAccumulator, feat)
|
||||||
|
@ -33,5 +33,5 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
|
||||||
|
|
||||||
def eject(self: "SD1T2IAdapter") -> None:
|
def eject(self: "SD1T2IAdapter") -> None:
|
||||||
for n, feat in zip(self.residual_indices, self._features, strict=True):
|
for n, feat in zip(self.residual_indices, self._features, strict=True):
|
||||||
self.target.DownBlocks[n].remove(feat)
|
self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
|
||||||
super().eject()
|
super().eject()
|
||||||
|
|
|
@ -6,6 +6,11 @@ import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder
|
from refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder
|
||||||
|
from refiners.foundationals.latent_diffusion.unet import (
|
||||||
|
ResidualAccumulator,
|
||||||
|
ResidualBlock,
|
||||||
|
ResidualConcatenator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(fl.Passthrough):
|
class TimestepEncoder(fl.Passthrough):
|
||||||
|
@ -22,54 +27,6 @@ class TimestepEncoder(fl.Passthrough):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(fl.Sum):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
num_groups: int = 32,
|
|
||||||
eps: float = 1e-5,
|
|
||||||
device: Device | str | None = None,
|
|
||||||
dtype: DType | None = None,
|
|
||||||
) -> None:
|
|
||||||
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
|
|
||||||
raise ValueError("Number of input and output channels must be divisible by num_groups.")
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
|
||||||
self.num_groups = num_groups
|
|
||||||
self.eps = eps
|
|
||||||
shortcut = (
|
|
||||||
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
|
|
||||||
if in_channels != out_channels
|
|
||||||
else fl.Identity()
|
|
||||||
)
|
|
||||||
super().__init__(
|
|
||||||
fl.Chain(
|
|
||||||
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
|
|
||||||
fl.SiLU(),
|
|
||||||
fl.Conv2d(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
padding=1,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
),
|
|
||||||
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
|
|
||||||
fl.SiLU(),
|
|
||||||
fl.Conv2d(
|
|
||||||
in_channels=out_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
padding=1,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
shortcut,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPLCrossAttention(CrossAttentionBlock2d):
|
class CLIPLCrossAttention(CrossAttentionBlock2d):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -205,34 +162,6 @@ class MiddleBlock(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResidualAccumulator(fl.Passthrough):
|
|
||||||
def __init__(self, n: int) -> None:
|
|
||||||
self.n = n
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
fl.Residual(
|
|
||||||
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
|
|
||||||
),
|
|
||||||
fl.SetContext(context="unet", key="residuals", callback=self.update),
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
|
|
||||||
residuals[self.n] = x
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualConcatenator(fl.Chain):
|
|
||||||
def __init__(self, n: int) -> None:
|
|
||||||
self.n = n
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
fl.Concatenate(
|
|
||||||
fl.Identity(),
|
|
||||||
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
|
|
||||||
dim=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SD1UNet(fl.Chain):
|
class SD1UNet(fl.Chain):
|
||||||
"""Stable Diffusion 1.5 U-Net.
|
"""Stable Diffusion 1.5 U-Net.
|
||||||
|
|
||||||
|
@ -275,9 +204,9 @@ class SD1UNet(fl.Chain):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
for residual_block in self.layers(ResidualBlock):
|
for residual_block in self.layers(ResidualBlock):
|
||||||
chain = residual_block.Chain
|
chain = residual_block.layer("Chain", fl.Chain)
|
||||||
RangeAdapter2d(
|
RangeAdapter2d(
|
||||||
target=chain.Conv2d_1,
|
target=chain.layer("Conv2d_1", fl.Conv2d),
|
||||||
channels=residual_block.out_channels,
|
channels=residual_block.out_channels,
|
||||||
embedding_dim=1280,
|
embedding_dim=1280,
|
||||||
context_key="timestep_embedding",
|
context_key="timestep_embedding",
|
||||||
|
|
|
@ -2,7 +2,7 @@ from torch import Tensor
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures
|
from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,18 +31,19 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
|
||||||
|
|
||||||
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
|
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
|
||||||
for n, feat in zip(self.residual_indices, self._features, strict=False):
|
for n, feat in zip(self.residual_indices, self._features, strict=False):
|
||||||
block = self.target.DownBlocks[n]
|
block = self.target.layer(("DownBlocks", n), fl.Chain)
|
||||||
sanity_check_t2i(block)
|
sanity_check_t2i(block)
|
||||||
block.insert_before_type(ResidualAccumulator, feat)
|
block.insert_before_type(ResidualAccumulator, feat)
|
||||||
|
|
||||||
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
|
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
|
||||||
sanity_check_t2i(self.target.MiddleBlock)
|
mid_block = self.target.layer("MiddleBlock", MiddleBlock)
|
||||||
self.target.MiddleBlock.append(self._features[-1])
|
sanity_check_t2i(mid_block)
|
||||||
|
mid_block.append(self._features[-1])
|
||||||
return super().inject(parent)
|
return super().inject(parent)
|
||||||
|
|
||||||
def eject(self: "SDXLT2IAdapter") -> None:
|
def eject(self: "SDXLT2IAdapter") -> None:
|
||||||
# See `inject` re: `strict=False`
|
# See `inject` re: `strict=False`
|
||||||
for n, feat in zip(self.residual_indices, self._features, strict=False):
|
for n, feat in zip(self.residual_indices, self._features, strict=False):
|
||||||
self.target.DownBlocks[n].remove(feat)
|
self.target.layer(("DownBlocks", n), fl.Chain).remove(feat)
|
||||||
self.target.MiddleBlock.remove(self._features[-1])
|
self.target.layer("MiddleBlock", MiddleBlock).remove(self._features[-1])
|
||||||
super().eject()
|
super().eject()
|
||||||
|
|
|
@ -72,7 +72,9 @@ class DoubleTextEncoder(fl.Chain):
|
||||||
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
||||||
fl.Lambda(func=self.concatenate_embeddings),
|
fl.Lambda(func=self.concatenate_embeddings),
|
||||||
)
|
)
|
||||||
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
|
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(
|
||||||
|
parent=self.layer("Parallel", fl.Parallel)
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
||||||
return super().__call__(text)
|
return super().__call__(text)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from refiners.foundationals.latent_diffusion.range_adapter import (
|
||||||
RangeEncoder,
|
RangeEncoder,
|
||||||
compute_sinusoidal_embedding,
|
compute_sinusoidal_embedding,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
from refiners.foundationals.latent_diffusion.unet import (
|
||||||
ResidualAccumulator,
|
ResidualAccumulator,
|
||||||
ResidualBlock,
|
ResidualBlock,
|
||||||
ResidualConcatenator,
|
ResidualConcatenator,
|
||||||
|
@ -267,9 +267,9 @@ class SDXLUNet(fl.Chain):
|
||||||
OutputBlock(device=device, dtype=dtype),
|
OutputBlock(device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
for residual_block in self.layers(ResidualBlock):
|
for residual_block in self.layers(ResidualBlock):
|
||||||
chain = residual_block.Chain
|
chain = residual_block.layer("Chain", fl.Chain)
|
||||||
RangeAdapter2d(
|
RangeAdapter2d(
|
||||||
target=chain.Conv2d_1,
|
target=chain.layer("Conv2d_1", fl.Conv2d),
|
||||||
channels=residual_block.out_channels,
|
channels=residual_block.out_channels,
|
||||||
embedding_dim=1280,
|
embedding_dim=1280,
|
||||||
context_key="timestep_embedding",
|
context_key="timestep_embedding",
|
||||||
|
|
79
src/refiners/foundationals/latent_diffusion/unet.py
Normal file
79
src/refiners/foundationals/latent_diffusion/unet.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(fl.Sum):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_groups: int = 32,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
|
||||||
|
raise ValueError("Number of input and output channels must be divisible by num_groups.")
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.num_groups = num_groups
|
||||||
|
self.eps = eps
|
||||||
|
shortcut = (
|
||||||
|
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
|
||||||
|
if in_channels != out_channels
|
||||||
|
else fl.Identity()
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
fl.Chain(
|
||||||
|
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Conv2d(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
shortcut,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAccumulator(fl.Passthrough):
|
||||||
|
def __init__(self, n: int) -> None:
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
fl.Residual(
|
||||||
|
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
|
||||||
|
),
|
||||||
|
fl.SetContext(context="unet", key="residuals", callback=self.update),
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
|
||||||
|
residuals[self.n] = x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConcatenator(fl.Chain):
|
||||||
|
def __init__(self, n: int) -> None:
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
fl.Concatenate(
|
||||||
|
fl.Identity(),
|
||||||
|
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
|
||||||
|
dim=1,
|
||||||
|
),
|
||||||
|
)
|
|
@ -188,6 +188,7 @@ class MaskEncoder(fl.Chain):
|
||||||
def get_no_mask_dense_embedding(
|
def get_no_mask_dense_embedding(
|
||||||
self, image_embedding_size: tuple[int, int], batch_size: int = 1
|
self, image_embedding_size: tuple[int, int], batch_size: int = 1
|
||||||
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
|
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
|
||||||
return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
|
no_mask_embedding = cast(Tensor, self.no_mask_embedding)
|
||||||
|
return no_mask_embedding.reshape(1, -1, 1, 1).expand(
|
||||||
batch_size, -1, image_embedding_size[0], image_embedding_size[1]
|
batch_size, -1, image_embedding_size[0], image_embedding_size[1]
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,8 +22,8 @@ def chain() -> Chain:
|
||||||
|
|
||||||
|
|
||||||
def test_weighted_module_adapter_insertion(chain: Chain):
|
def test_weighted_module_adapter_insertion(chain: Chain):
|
||||||
parent = chain.Chain
|
parent = chain.layer("Chain", Chain)
|
||||||
adaptee = parent.Linear
|
adaptee = parent.layer("Linear", Linear)
|
||||||
|
|
||||||
adapter = DummyLinearAdapter(adaptee).inject(parent)
|
adapter = DummyLinearAdapter(adaptee).inject(parent)
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ def test_weighted_module_adapter_insertion(chain: Chain):
|
||||||
|
|
||||||
def test_chain_adapter_insertion(chain: Chain):
|
def test_chain_adapter_insertion(chain: Chain):
|
||||||
parent = chain
|
parent = chain
|
||||||
adaptee = parent.Chain
|
adaptee = parent.layer("Chain", Chain)
|
||||||
|
|
||||||
adapter = DummyChainAdapter(adaptee)
|
adapter = DummyChainAdapter(adaptee)
|
||||||
assert adaptee.parent == parent
|
assert adaptee.parent == parent
|
||||||
|
@ -58,20 +58,20 @@ def test_chain_adapter_insertion(chain: Chain):
|
||||||
|
|
||||||
|
|
||||||
def test_weighted_module_adapter_structural_copy(chain: Chain):
|
def test_weighted_module_adapter_structural_copy(chain: Chain):
|
||||||
parent = chain.Chain
|
parent = chain.layer("Chain", Chain)
|
||||||
adaptee = parent.Linear
|
adaptee = parent.layer("Linear", Linear)
|
||||||
|
|
||||||
DummyLinearAdapter(adaptee).inject(parent)
|
DummyLinearAdapter(adaptee).inject(parent)
|
||||||
|
|
||||||
clone = chain.structural_copy()
|
clone = chain.structural_copy()
|
||||||
cloned_adapter = clone.Chain.DummyLinearAdapter
|
cloned_adapter = clone.layer(("Chain", "DummyLinearAdapter"), DummyLinearAdapter)
|
||||||
assert cloned_adapter.parent == clone.Chain
|
assert cloned_adapter.parent == clone.Chain
|
||||||
assert cloned_adapter.target == adaptee
|
assert cloned_adapter.target == adaptee
|
||||||
|
|
||||||
|
|
||||||
def test_chain_adapter_structural_copy(chain: Chain):
|
def test_chain_adapter_structural_copy(chain: Chain):
|
||||||
# Chain adapters cannot be copied by default.
|
# Chain adapters cannot be copied by default.
|
||||||
adapter = DummyChainAdapter(chain.Chain).inject()
|
adapter = DummyChainAdapter(chain.layer("Chain", Chain)).inject()
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
chain.structural_copy()
|
chain.structural_copy()
|
||||||
|
|
|
@ -27,6 +27,7 @@ def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
|
||||||
assert conv_lora.scale == 1.0
|
assert conv_lora.scale == 1.0
|
||||||
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
|
assert conv_lora.in_channels == conv_lora.down.in_channels == 16
|
||||||
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
|
assert conv_lora.out_channels == conv_lora.up.out_channels == 8
|
||||||
|
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
|
||||||
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
|
assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
|
||||||
# padding is set so the spatial dimensions are preserved
|
# padding is set so the spatial dimensions are preserved
|
||||||
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
|
assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
|
||||||
|
@ -39,10 +40,12 @@ def test_scale_setter(lora: LinearLora) -> None:
|
||||||
|
|
||||||
|
|
||||||
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
|
def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
|
||||||
|
assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
|
||||||
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
|
new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
|
||||||
x = torch.randn(1, 320)
|
x = torch.randn(1, 320)
|
||||||
assert torch.allclose(lora(x), new_lora(x))
|
assert torch.allclose(lora(x), new_lora(x))
|
||||||
|
|
||||||
|
assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
|
||||||
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
|
new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
|
||||||
x = torch.randn(1, 16, 64, 64)
|
x = torch.randn(1, 16, 64, 64)
|
||||||
assert torch.allclose(conv_lora(x), new_conv_lora(x))
|
assert torch.allclose(conv_lora(x), new_conv_lora(x))
|
||||||
|
|
|
@ -15,8 +15,9 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-
|
||||||
dtype = torch.float64
|
dtype = torch.float64
|
||||||
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
|
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
|
||||||
|
|
||||||
adaptee = chain.RangeEncoder.Linear_1
|
range_encoder = chain.layer("RangeEncoder", RangeEncoder)
|
||||||
adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder)
|
adaptee = range_encoder.layer("Linear_1", Linear)
|
||||||
|
adapter = DummyLinearAdapter(adaptee).inject(range_encoder)
|
||||||
|
|
||||||
assert adapter.parent == chain.RangeEncoder
|
assert adapter.parent == chain.RangeEncoder
|
||||||
|
|
||||||
|
|
|
@ -30,8 +30,9 @@ def test_chain_getitem_accessor() -> None:
|
||||||
|
|
||||||
def test_chain_find_parent():
|
def test_chain_find_parent():
|
||||||
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
|
chain = fl.Chain(fl.Chain(fl.Linear(1, 1)))
|
||||||
|
subchain = chain.layer("Chain", fl.Chain)
|
||||||
|
|
||||||
assert chain.find_parent(chain.Chain.Linear) == chain.Chain
|
assert chain.find_parent(subchain.layer("Linear", fl.Linear)) == subchain
|
||||||
assert chain.find_parent(fl.Linear(1, 1)) is None
|
assert chain.find_parent(fl.Linear(1, 1)) is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,17 +65,20 @@ def test_chain_walk() -> None:
|
||||||
fl.Chain(),
|
fl.Chain(),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert list(chain.walk()) == [(chain.Sum, chain), (chain.Chain, chain)]
|
sum_ = chain.layer("Sum", fl.Sum)
|
||||||
|
sum_chain = sum_.layer("Chain", fl.Chain)
|
||||||
|
|
||||||
|
assert list(chain.walk()) == [(sum_, chain), (chain.Chain, chain)]
|
||||||
assert list(chain.walk(fl.Linear)) == [
|
assert list(chain.walk(fl.Linear)) == [
|
||||||
(chain.Sum.Chain.Linear, chain.Sum.Chain),
|
(sum_chain.Linear, sum_chain),
|
||||||
(chain.Sum.Linear, chain.Sum),
|
(sum_.Linear, sum_),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert list(chain.walk(recurse=True)) == [
|
assert list(chain.walk(recurse=True)) == [
|
||||||
(chain.Sum, chain),
|
(sum_, chain),
|
||||||
(chain.Sum.Chain, chain.Sum),
|
(sum_chain, sum_),
|
||||||
(chain.Sum.Chain.Linear, chain.Sum.Chain),
|
(sum_chain.Linear, sum_chain),
|
||||||
(chain.Sum.Linear, chain.Sum),
|
(sum_.Linear, sum_),
|
||||||
(chain.Chain, chain),
|
(chain.Chain, chain),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -98,6 +102,29 @@ def test_chain_walk_stop_iteration() -> None:
|
||||||
assert len(list(chain.walk(predicate))) == 1
|
assert len(list(chain.walk(predicate))) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_chain_layer() -> None:
|
||||||
|
chain = fl.Chain(
|
||||||
|
fl.Sum(fl.Chain(), fl.Chain()),
|
||||||
|
)
|
||||||
|
|
||||||
|
sum_ = chain.layer(0, fl.Sum)
|
||||||
|
assert chain.layer("Sum", fl.Sum) == sum_
|
||||||
|
assert chain.layer("Sum", fl.Chain) == sum_
|
||||||
|
|
||||||
|
chain_2 = chain.layer((0, 1), fl.Chain)
|
||||||
|
assert chain.layer((0, 1)) == chain_2
|
||||||
|
assert chain.layer((0, "Chain_2"), fl.Chain) == chain_2
|
||||||
|
assert chain.layer(("Sum", "Chain_2"), fl.Chain) == chain_2
|
||||||
|
|
||||||
|
assert chain.layer((), fl.Chain) == chain
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
chain.layer((0, 1), fl.Sum)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
chain.layer((), fl.Sum)
|
||||||
|
|
||||||
|
|
||||||
def test_chain_layers() -> None:
|
def test_chain_layers() -> None:
|
||||||
chain = fl.Chain(
|
chain = fl.Chain(
|
||||||
fl.Chain(fl.Chain(fl.Chain())),
|
fl.Chain(fl.Chain(fl.Chain())),
|
||||||
|
@ -214,11 +241,12 @@ def test_chain_replace() -> None:
|
||||||
fl.Linear(1, 1),
|
fl.Linear(1, 1),
|
||||||
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
|
fl.Chain(fl.Linear(1, 1), fl.Linear(1, 1)),
|
||||||
)
|
)
|
||||||
|
subchain = chain.layer("Chain", fl.Chain)
|
||||||
|
|
||||||
assert isinstance(chain.Chain[1], fl.Linear)
|
assert isinstance(subchain[1], fl.Linear)
|
||||||
chain.Chain.replace(chain.Chain[1], fl.Conv2d(1, 1, 1))
|
subchain.replace(subchain[1], fl.Conv2d(1, 1, 1))
|
||||||
assert len(chain) == 3
|
assert len(chain) == 3
|
||||||
assert isinstance(chain.Chain[1], fl.Conv2d)
|
assert isinstance(subchain[1], fl.Conv2d)
|
||||||
|
|
||||||
|
|
||||||
def test_chain_structural_copy() -> None:
|
def test_chain_structural_copy() -> None:
|
||||||
|
@ -236,15 +264,18 @@ def test_chain_structural_copy() -> None:
|
||||||
|
|
||||||
m2 = m.structural_copy()
|
m2 = m.structural_copy()
|
||||||
|
|
||||||
assert m.Linear == m2.Linear
|
m_sum = m.layer("Sum", fl.Sum)
|
||||||
assert m.Sum.Linear_1 == m2.Sum.Linear_1
|
m2_sum = m2.layer("Sum", fl.Sum)
|
||||||
assert m.Sum.Linear_2 == m2.Sum.Linear_2
|
|
||||||
|
|
||||||
assert m.Sum != m2.Sum
|
assert m.Linear == m2.Linear
|
||||||
|
assert m_sum.Linear_1 == m2_sum.Linear_1
|
||||||
|
assert m_sum.Linear_2 == m2_sum.Linear_2
|
||||||
|
|
||||||
|
assert m_sum != m2_sum
|
||||||
assert m != m2
|
assert m != m2
|
||||||
|
|
||||||
assert m.Sum.parent == m
|
assert m_sum.parent == m
|
||||||
assert m2.Sum.parent == m2
|
assert m2_sum.parent == m2
|
||||||
|
|
||||||
y2 = m2(x)
|
y2 = m2(x)
|
||||||
assert y2.shape == (7, 12)
|
assert y2.shape == (7, 12)
|
||||||
|
|
|
@ -10,9 +10,12 @@ def test_module_get_path() -> None:
|
||||||
fl.Sum(),
|
fl.Sum(),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1) == "Chain.Sum_1.Linear_2"
|
sum_1 = chain.layer("Sum_1", fl.Sum)
|
||||||
assert chain.Sum_1.Linear_2.get_path(parent=chain.Sum_1, top=chain.Sum_1) == "Sum.Linear_2"
|
linear_2 = sum_1.layer("Linear_2", fl.Linear)
|
||||||
assert chain.Sum_1.get_path() == "Chain.Sum_1"
|
|
||||||
|
assert linear_2.get_path(parent=sum_1) == "Chain.Sum_1.Linear_2"
|
||||||
|
assert linear_2.get_path(parent=sum_1, top=sum_1) == "Sum.Linear_2"
|
||||||
|
assert sum_1.get_path() == "Chain.Sum_1"
|
||||||
|
|
||||||
|
|
||||||
def test_module_basic_attributes() -> None:
|
def test_module_basic_attributes() -> None:
|
||||||
|
|
|
@ -85,13 +85,14 @@ def cat_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.
|
||||||
|
|
||||||
|
|
||||||
def test_tokenizer_with_special_character():
|
def test_tokenizer_with_special_character():
|
||||||
clip_tokenizer = fl.Chain(CLIPTokenizer())
|
clip_tokenizer_chain = fl.Chain(CLIPTokenizer())
|
||||||
token_extender = TokenExtender(clip_tokenizer.CLIPTokenizer)
|
original_clip_tokenizer = clip_tokenizer_chain.layer("CLIPTokenizer", CLIPTokenizer)
|
||||||
new_token_id = max(clip_tokenizer.CLIPTokenizer.token_to_id_mapping.values()) + 42
|
token_extender = TokenExtender(original_clip_tokenizer)
|
||||||
|
new_token_id = max(original_clip_tokenizer.token_to_id_mapping.values()) + 42
|
||||||
token_extender.add_token("*", new_token_id)
|
token_extender.add_token("*", new_token_id)
|
||||||
token_extender.inject(clip_tokenizer)
|
token_extender.inject(clip_tokenizer_chain)
|
||||||
|
|
||||||
adapted_clip_tokenizer = clip_tokenizer.ensure_find(CLIPTokenizer)
|
adapted_clip_tokenizer = clip_tokenizer_chain.ensure_find(CLIPTokenizer)
|
||||||
|
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
adapted_clip_tokenizer.encode("*"),
|
adapted_clip_tokenizer.encode("*"),
|
||||||
|
|
|
@ -4,6 +4,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.fluxion import manual_seed
|
from refiners.fluxion import manual_seed
|
||||||
|
from refiners.fluxion.layers import Chain
|
||||||
from refiners.fluxion.utils import no_grad
|
from refiners.fluxion.utils import no_grad
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
|
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
|
||||||
|
@ -33,7 +34,7 @@ def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
|
||||||
|
|
||||||
|
|
||||||
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||||
num_blocks = len(unet.UpBlocks)
|
num_blocks = len(unet.layer("UpBlocks", Chain))
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
|
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
|
||||||
|
|
|
@ -16,10 +16,11 @@ from tests.foundationals.segment_anything.utils import (
|
||||||
)
|
)
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion import manual_seed
|
from refiners.fluxion import manual_seed
|
||||||
from refiners.fluxion.model_converter import ModelConverter
|
from refiners.fluxion.model_converter import ModelConverter
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_tensors, no_grad
|
||||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
|
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention, RelativePositionAttention
|
||||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||||
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
from refiners.foundationals.segment_anything.transformer import TwoWayTransformerLayer
|
||||||
|
|
||||||
|
@ -104,17 +105,22 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
|
||||||
manual_seed(seed=0)
|
manual_seed(seed=0)
|
||||||
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
|
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
|
||||||
|
|
||||||
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn) # type: ignore
|
attention = cast(nn.Module, facebook_sam_h.image_encoder.blocks[0].attn)
|
||||||
|
|
||||||
refiners_attention = FusedSelfAttention(
|
refiners_attention = FusedSelfAttention(
|
||||||
embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
|
embedding_dim=1280, num_heads=16, spatial_size=(14, 14), device=facebook_sam_h.device
|
||||||
)
|
)
|
||||||
refiners_attention.Linear_1.weight = attention.qkv.weight # type: ignore
|
|
||||||
refiners_attention.Linear_1.bias = attention.qkv.bias # type: ignore
|
rpa = refiners_attention.layer("RelativePositionAttention", RelativePositionAttention)
|
||||||
refiners_attention.Linear_2.weight = attention.proj.weight # type: ignore
|
linear_1 = refiners_attention.layer("Linear_1", fl.Linear)
|
||||||
refiners_attention.Linear_2.bias = attention.proj.bias # type: ignore
|
linear_2 = refiners_attention.layer("Linear_2", fl.Linear)
|
||||||
refiners_attention.RelativePositionAttention.horizontal_embedding = attention.rel_pos_w
|
|
||||||
refiners_attention.RelativePositionAttention.vertical_embedding = attention.rel_pos_h
|
linear_1.weight = attention.qkv.weight
|
||||||
|
linear_1.bias = attention.qkv.bias
|
||||||
|
linear_2.weight = attention.proj.weight
|
||||||
|
linear_2.bias = attention.proj.bias
|
||||||
|
rpa.horizontal_embedding = attention.rel_pos_w
|
||||||
|
rpa.vertical_embedding = attention.rel_pos_h
|
||||||
|
|
||||||
y_1 = attention(x)
|
y_1 = attention(x)
|
||||||
assert y_1.shape == x.shape
|
assert y_1.shape == x.shape
|
||||||
|
|
Loading…
Reference in a new issue