remove structural_attrs

Benjamin Trom 2023-09-14 14:34:31 +02:00
parent 121ef4df39
commit 1cb798e8ae
16 changed files with 21 additions and 147 deletions

View file

@@ -12,8 +12,6 @@ TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PE
 class Lora(fl.Chain):
-    structural_attrs = ["in_features", "out_features", "rank", "scale"]
-
     def __init__(
         self,
         in_features: int,
@@ -56,8 +54,6 @@ class Lora(fl.Chain):
 class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
-    structural_attrs = ["in_features", "out_features", "rank", "scale"]
-
     def __init__(
         self,
         target: fl.Linear,

View file

@@ -82,18 +82,6 @@ class ScaledDotProductAttention(Module):
 class Attention(Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "num_heads",
-        "heads_dim",
-        "key_embedding_dim",
-        "value_embedding_dim",
-        "inner_dim",
-        "use_bias",
-        "is_causal",
-        "is_optimized",
-    ]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -180,8 +168,6 @@ class SelfAttention(Attention):
 class SelfAttention2d(SelfAttention):
-    structural_attrs = Attention.structural_attrs + ["channels"]
-
     def __init__(
         self,
         channels: int,

View file

@@ -44,8 +44,6 @@ def generate_unique_names(
 class UseContext(ContextModule):
-    structural_attrs = ["context", "key", "func"]
-
     def __init__(self, context: str, key: str) -> None:
         super().__init__()
         self.context = context
@@ -74,8 +72,6 @@ class SetContext(ContextModule):
     #TODO Is there a way to create the context if it doesn't exist?
     """
-    structural_attrs = ["context", "key", "callback"]
-
     def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
         super().__init__()
         self.context = context
@@ -374,29 +370,16 @@ class Chain(ContextModule):
         Such copies can be adapted without disrupting the base model, but do not
         require extra GPU memory since the weights are in the leaves and hence not copied.
-
-        This assumes all subclasses define the class variable `structural_attrs` which
-        contains a list of basic attributes set in the constructor. In complicated cases
-        it may be required to overwrite that method.
         """
         if hasattr(self, "_pre_structural_copy"):
             self._pre_structural_copy()

         modules = [structural_copy(m) for m in self]

-        # Instantiate the right subclass, but do not initialize.
-        clone = object.__new__(self.__class__)
-
-        # Copy all basic attributes of the class declared in `structural_attrs`.
-        for k in self.__class__.structural_attrs:
-            setattr(clone, k, getattr(self, k))
-
-        # Call constructor of Chain, which among other things refreshes the context tree.
-        Chain.__init__(clone, *modules)
+        clone = super().structural_copy()
+        clone._provider = ContextProvider()

         for module in modules:
-            if isinstance(module, ContextModule):
-                module._set_parent(clone)
+            clone.append(module=module)

         if hasattr(clone, "_post_structural_copy"):
             clone._post_structural_copy(self)
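
With this hunk, Chain.structural_copy delegates the attribute copy to the generic ContextModule.structural_copy (rewritten further down in this commit), gives the clone a fresh ContextProvider, and re-appends the structurally copied children instead of reparenting them by hand. A hedged usage sketch of what this buys subclass authors; MyBlock is a made-up class, not part of this commit, and it assumes fl.Linear accepts the in_features/out_features keywords used elsewhere in this diff:

# Hedged sketch: a Chain subclass no longer needs to declare `structural_attrs`
# for structural_copy() to work. `MyBlock` is illustrative only.
import refiners.fluxion.layers as fl


class MyBlock(fl.Chain):
    def __init__(self, embedding_dim: int) -> None:
        # A plain attribute set before super().__init__ is now copied automatically.
        self.embedding_dim = embedding_dim
        super().__init__(fl.Linear(in_features=embedding_dim, out_features=embedding_dim))


block = MyBlock(embedding_dim=8)
clone = block.structural_copy()

assert clone.embedding_dim == 8                # copied without any structural_attrs list
assert next(iter(clone)) is next(iter(block))  # the Linear leaf (and its weights) is shared, per the docstring above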
@@ -477,7 +460,6 @@ class Breakpoint(ContextModule):
 class Concatenate(Chain):
     _tag = "CAT"
-    structural_attrs = ["dim"]

     def __init__(self, *modules: Module, dim: int = 0) -> None:
         super().__init__(*modules)

View file

@@ -17,8 +17,6 @@ class Converter(ContextModule):
     Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
     """
-    structural_attrs = ["set_device", "set_dtype"]
-
     def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
         super().__init__()
         self.set_device = set_device

View file

@@ -1,5 +1,7 @@
 from inspect import signature, Parameter
+import sys
 from pathlib import Path
+from types import ModuleType
 from typing import Any, Generator, TypeVar, TypedDict, cast

 from torch import device as Device, dtype as DType
@@ -96,11 +98,6 @@ class ContextModule(Module):
     _parent: "list[Chain]"
     _can_refresh_parent: bool = True # see usage in Adapter and Chain

-    # Contains simple attributes set on the instance by `__init__` in subclasses
-    # and copied by `structural_copy`. Note that is not the case of `device` since
-    # Chain's __init__ takes care of it.
-    structural_attrs: list[str] = []
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, *kwargs)
         self._parent = []
@@ -139,9 +136,20 @@ class ContextModule(Module):
     def structural_copy(self: TContextModule) -> TContextModule:
         clone = object.__new__(self.__class__)
-        for k in self.__class__.structural_attrs:
+
+        not_torch_attributes = [
+            key
+            for key, value in self.__dict__.items()
+            if not key.startswith("_")
+            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
+            and "torch" not in sys.modules[type(value).__module__].__name__
+        ]
+
+        for k in not_torch_attributes:
             setattr(clone, k, getattr(self, k))
-        ContextModule.__init__(clone)
+
+        ContextModule.__init__(self=clone)
+
         return clone
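
In other words, instead of copying a hand-maintained structural_attrs list, the rewritten ContextModule.structural_copy copies every public instance attribute whose type is not defined in a torch module. A standalone sketch of that filter, illustrative only; the Dummy class below is made up and not part of refiners:

# Minimal reproduction of the attribute filter used above.
import sys
from types import ModuleType

import torch


class Dummy:
    def __init__(self) -> None:
        self.in_channels = 4             # plain int -> copied
        self.scale = 0.5                 # plain float -> copied
        self.weight = torch.zeros(2, 2)  # torch.Tensor -> skipped
        self._private = "x"              # leading underscore -> skipped


obj = Dummy()
not_torch_attributes = [
    key
    for key, value in obj.__dict__.items()
    if not key.startswith("_")
    and isinstance(sys.modules.get(type(value).__module__), ModuleType)
    and "torch" not in sys.modules[type(value).__module__].__name__
]
print(not_torch_attributes)  # ['in_channels', 'scale']

Tensors, parameters and sub-modules are left out because their types live under torch, so a structural copy keeps sharing the original weights.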

View file

@@ -9,8 +9,6 @@ from torch import Tensor, Size, device as Device, dtype as DType
 class Downsample(Chain):
-    structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]
-
     def __init__(
         self,
         channels: int,
@@ -59,8 +57,6 @@ class Interpolate(Module):
 class Upsample(Chain):
-    structural_attrs = ["channels", "upsample_factor"]
-
     def __init__(
         self,
         channels: int,

View file

@@ -3,8 +3,6 @@ import refiners.fluxion.layers as fl
 class PositionalEncoder(fl.Chain):
-    structural_attrs = ["max_sequence_length", "embedding_dim"]
-
     def __init__(
         self,
         max_sequence_length: int,
@@ -33,8 +31,6 @@ class PositionalEncoder(fl.Chain):
 class FeedForward(fl.Chain):
-    structural_attrs = ["embedding_dim", "feedforward_dim"]
-
     def __init__(
         self,
         embedding_dim: int,

View file

@@ -4,8 +4,6 @@ from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
 class ClassEncoder(fl.Chain):
-    structural_attrs = ["embedding_dim"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -20,8 +18,6 @@ class ClassEncoder(fl.Chain):
 class PatchEncoder(fl.Chain):
-    structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]
-
     def __init__(
         self,
         in_channels: int,
@@ -50,8 +46,6 @@ class PatchEncoder(fl.Chain):
 class TransformerLayer(fl.Chain):
-    structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]
-
     def __init__(
         self,
         embedding_dim: int = 768,
@@ -80,8 +74,6 @@ class TransformerLayer(fl.Chain):
 class ViTEmbeddings(fl.Chain):
-    structural_attrs = ["image_size", "embedding_dim", "patch_size"]
-
     def __init__(
         self,
         image_size: int = 224,
@@ -90,6 +82,9 @@ class ViTEmbeddings(fl.Chain):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
+        self.image_size = image_size
+        self.embedding_dim = embedding_dim
+        self.patch_size = patch_size
         super().__init__(
             fl.Concatenate(
                 ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
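
The three added assignments show the companion pattern used across this commit: structural constructor arguments are stored on self before super().__init__, so the generic copy in ContextModule.structural_copy (shown earlier) finds them in the instance __dict__. A minimal hedged sketch of that pattern; TinyEncoder is a made-up class, not refiners code, and fl.Linear is assumed to take in_features/out_features as elsewhere in this diff:

# Hedged sketch of the "store structural arguments on self" pattern used above.
import refiners.fluxion.layers as fl


class TinyEncoder(fl.Chain):
    def __init__(self, image_size: int = 224, patch_size: int = 32) -> None:
        # Stored before super().__init__ so structural_copy() can carry them over.
        self.image_size = image_size
        self.patch_size = patch_size
        super().__init__(fl.Linear(in_features=patch_size, out_features=patch_size))


encoder = TinyEncoder()
copy = encoder.structural_copy()
assert (copy.image_size, copy.patch_size) == (224, 32)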
@@ -118,16 +113,6 @@ class ViTEmbeddings(fl.Chain):
 class CLIPImageEncoder(fl.Chain):
-    structural_attrs = [
-        "image_size",
-        "embedding_dim",
-        "output_dim",
-        "patch_size",
-        "num_layers",
-        "num_attention_heads",
-        "feedforward_dim",
-    ]
-
     def __init__(
         self,
         image_size: int = 224,

View file

@@ -5,8 +5,6 @@ from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 class TokenEncoder(fl.Embedding):
-    structural_attrs = ["vocabulary_size", "embedding_dim"]
-
     def __init__(
         self,
         vocabulary_size: int,
@@ -25,8 +23,6 @@ class TokenEncoder(fl.Embedding):
 class TransformerLayer(fl.Chain):
-    structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -74,17 +70,6 @@ class TransformerLayer(fl.Chain):
 class CLIPTextEncoder(fl.Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "max_sequence_length",
-        "vocabulary_size",
-        "num_layers",
-        "num_attention_heads",
-        "feedforward_dim",
-        "layer_norm_eps",
-        "use_quick_gelu",
-    ]
-
     def __init__(
         self,
         embedding_dim: int = 768,

View file

@@ -17,8 +17,6 @@ from PIL import Image
 class Resnet(Sum):
-    structural_attrs = ["in_channels", "out_channels"]
-
     def __init__(
         self,
         in_channels: int,
@@ -125,8 +123,6 @@ class Encoder(Chain):
 class Decoder(Chain):
-    structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
         self.latent_dim: int = 4
@@ -196,7 +192,6 @@ class Decoder(Chain):
 class LatentDiffusionAutoencoder(Chain):
-    structural_attrs = ["encoder_scale"]
     encoder_scale = 0.18125

     def __init__(

View file

@@ -23,8 +23,6 @@ from refiners.fluxion.layers import (
 class CrossAttentionBlock(Chain):
-    structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -85,8 +83,6 @@ class CrossAttentionBlock(Chain):
 class StatefulFlatten(Chain):
-    structural_attrs = ["start_dim", "end_dim"]
-
     def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
         self.start_dim = start_dim
         self.end_dim = end_dim
@@ -103,19 +99,6 @@ class StatefulFlatten(Chain):
 class CrossAttentionBlock2d(Sum):
-    structural_attrs = [
-        "channels",
-        "in_channels",
-        "out_channels",
-        "context_embedding_dim",
-        "num_attention_heads",
-        "num_attention_layers",
-        "num_groups",
-        "context_key",
-        "use_linear_projection",
-        "projection_type",
-    ]
-
     def __init__(
         self,
         channels: int,

View file

@@ -22,8 +22,6 @@ TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
 class ImageProjection(fl.Chain):
-    structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]
-
     def __init__(
         self,
         clip_image_embedding_dim: int = 1024,
@@ -57,8 +55,6 @@ class InjectionPoint(fl.Chain):
 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
-    structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]
-
     def __init__(
         self,
         target: fl.Attention,

View file

@@ -21,8 +21,6 @@ def compute_sinusoidal_embedding(
 class RangeEncoder(fl.Chain):
-    structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]
-
     def __init__(
         self,
         sinuosidal_embedding_dim: int,
@@ -45,8 +43,6 @@ class RangeEncoder(fl.Chain):
 class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
-    structural_attrs = ["channels", "embedding_dim", "context_key"]
-
     def __init__(
         self,
         target: fl.Conv2d,

View file

@@ -19,8 +19,6 @@ class ConditionEncoder(Chain):
     Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
     """
-    structural_attrs = ["out_channels"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.out_channels = (16, 32, 96, 256)
         super().__init__(
@@ -72,8 +70,6 @@ class ConditionEncoder(Chain):
 class Controlnet(Passthrough):
-    structural_attrs = ["scale", "name"]
-
     def __init__(
         self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
     ) -> None:

View file

@@ -24,8 +24,6 @@ class TimestepEncoder(fl.Passthrough):
 class ResidualBlock(fl.Sum):
-    structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]
-
     def __init__(
         self,
         in_channels: int,
@@ -92,8 +90,6 @@ class CLIPLCrossAttention(CrossAttentionBlock2d):
 class DownBlocks(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(
         self,
         in_channels: int,
@@ -211,8 +207,6 @@ class MiddleBlock(fl.Chain):
 class ResidualAccumulator(fl.Passthrough):
-    structural_attrs = ["n"]
-
     def __init__(self, n: int) -> None:
         self.n = n
@@ -228,8 +222,6 @@ class ResidualAccumulator(fl.Passthrough):
 class ResidualConcatenator(fl.Chain):
-    structural_attrs = ["n"]
-
     def __init__(self, n: int) -> None:
         self.n = n
@@ -243,8 +235,6 @@ class ResidualConcatenator(fl.Chain):
 class SD1UNet(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
         super().__init__(

View file

@@ -16,8 +16,6 @@ from refiners.foundationals.latent_diffusion.range_adapter import (
 class TextTimeEmbedding(fl.Chain):
-    structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.timestep_embedding_dim = 1280
         self.time_ids_embedding_dim = 256
@@ -54,8 +52,6 @@ class TextTimeEmbedding(fl.Chain):
 class TimestepEncoder(fl.Passthrough):
-    structural_attrs = ["timestep_embedding_dim"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.timestep_embedding_dim = 1280
         super().__init__(
@@ -98,8 +94,6 @@ class SDXLCrossAttention(CrossAttentionBlock2d):
 class DownBlocks(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
@@ -158,8 +152,6 @@ class DownBlocks(fl.Chain):
 class UpBlocks(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         first_blocks = [
             fl.Chain(
@@ -225,8 +217,6 @@ class UpBlocks(fl.Chain):
 class MiddleBlock(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__(
             ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
@@ -238,8 +228,6 @@ class MiddleBlock(fl.Chain):
 class OutputBlock(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__(
             fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
@@ -249,8 +237,6 @@ class OutputBlock(fl.Chain):
 class SDXLUNet(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
         super().__init__(