remove structural_attrs

Benjamin Trom 2023-09-14 14:34:31 +02:00
parent 121ef4df39
commit 1cb798e8ae
16 changed files with 21 additions and 147 deletions

View file

@@ -12,8 +12,6 @@ TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PE
class Lora(fl.Chain):
structural_attrs = ["in_features", "out_features", "rank", "scale"]
def __init__(
self,
in_features: int,
@@ -56,8 +54,6 @@ class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
structural_attrs = ["in_features", "out_features", "rank", "scale"]
def __init__(
self,
target: fl.Linear,

View file

@@ -82,18 +82,6 @@ class ScaledDotProductAttention(Module):
class Attention(Chain):
structural_attrs = [
"embedding_dim",
"num_heads",
"heads_dim",
"key_embedding_dim",
"value_embedding_dim",
"inner_dim",
"use_bias",
"is_causal",
"is_optimized",
]
def __init__(
self,
embedding_dim: int,
@@ -180,8 +168,6 @@ class SelfAttention(Attention):
class SelfAttention2d(SelfAttention):
structural_attrs = Attention.structural_attrs + ["channels"]
def __init__(
self,
channels: int,

View file

@@ -44,8 +44,6 @@ def generate_unique_names(
class UseContext(ContextModule):
structural_attrs = ["context", "key", "func"]
def __init__(self, context: str, key: str) -> None:
super().__init__()
self.context = context
@@ -74,8 +72,6 @@ class SetContext(ContextModule):
#TODO Is there a way to create the context if it doesn't exist?
"""
structural_attrs = ["context", "key", "callback"]
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
super().__init__()
self.context = context
@@ -374,29 +370,16 @@ class Chain(ContextModule):
Such copies can be adapted without disrupting the base model, but do not
require extra GPU memory since the weights are in the leaves and hence not copied.
This assumes all subclasses define the class variable `structural_attrs` which
contains a list of basic attributes set in the constructor. In complicated cases
it may be required to overwrite that method.
"""
if hasattr(self, "_pre_structural_copy"):
self._pre_structural_copy()
modules = [structural_copy(m) for m in self]
# Instantiate the right subclass, but do not initialize.
clone = object.__new__(self.__class__)
# Copy all basic attributes of the class declared in `structural_attrs`.
for k in self.__class__.structural_attrs:
setattr(clone, k, getattr(self, k))
# Call constructor of Chain, which among other things refreshes the context tree.
Chain.__init__(clone, *modules)
clone = super().structural_copy()
clone._provider = ContextProvider()
for module in modules:
if isinstance(module, ContextModule):
module._set_parent(clone)
clone.append(module=module)
if hasattr(clone, "_post_structural_copy"):
clone._post_structural_copy(self)
@@ -477,7 +460,6 @@ class Breakpoint(ContextModule):
class Concatenate(Chain):
_tag = "CAT"
structural_attrs = ["dim"]
def __init__(self, *modules: Module, dim: int = 0) -> None:
super().__init__(*modules)

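A minimal usage sketch of `structural_copy` on a plain Chain (illustrative only: it assumes the `refiners.fluxion.layers` API and arbitrary layer sizes). Leaves such as `fl.Linear` are not ContextModules, so they are shared rather than copied:

    import refiners.fluxion.layers as fl

    # An arbitrary chain; the layer sizes are illustrative only.
    chain = fl.Chain(
        fl.Linear(in_features=8, out_features=8),
        fl.Linear(in_features=8, out_features=4),
    )

    clone = chain.structural_copy()
    assert clone is not chain  # new structure and a fresh context provider

    # The leaf modules, and hence the weights, are shared: no extra GPU memory.
    for original, copy in zip(chain, clone):
        assert copy.weight is original.weight
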
View file

@@ -17,8 +17,6 @@ class Converter(ContextModule):
Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
"""
structural_attrs = ["set_device", "set_dtype"]
def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
super().__init__()
self.set_device = set_device

View file

@@ -1,5 +1,7 @@
from inspect import signature, Parameter
import sys
from pathlib import Path
from types import ModuleType
from typing import Any, Generator, TypeVar, TypedDict, cast
from torch import device as Device, dtype as DType
@@ -96,11 +98,6 @@ class ContextModule(Module):
_parent: "list[Chain]"
_can_refresh_parent: bool = True # see usage in Adapter and Chain
# Contains simple attributes set on the instance by `__init__` in subclasses
# and copied by `structural_copy`. Note that is not the case of `device` since
# Chain's __init__ takes care of it.
structural_attrs: list[str] = []
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, *kwargs)
self._parent = []
@@ -139,9 +136,20 @@ class ContextModule(Module):
def structural_copy(self: TContextModule) -> TContextModule:
clone = object.__new__(self.__class__)
for k in self.__class__.structural_attrs:
not_torch_attributes = [
key
for key, value in self.__dict__.items()
if not key.startswith("_")
and isinstance(sys.modules.get(type(value).__module__), ModuleType)
and "torch" not in sys.modules[type(value).__module__].__name__
]
for k in not_torch_attributes:
setattr(clone, k, getattr(self, k))
ContextModule.__init__(clone)
ContextModule.__init__(self=clone)
return clone

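In other words, instead of the explicit `structural_attrs` allow-list, `structural_copy` now copies every public instance attribute whose type is not defined in a torch module. A self-contained sketch of that filter (the helper name and the Example class below are illustrative, not part of the library):

    import sys
    from types import ModuleType

    import torch

    def non_torch_attribute_names(obj: object) -> list[str]:
        # Same heuristic as above: keep public attributes whose type lives outside
        # torch (ints, floats, strings, tuples, ...); skip anything whose type
        # comes from a torch module, such as tensors.
        return [
            key
            for key, value in obj.__dict__.items()
            if not key.startswith("_")
            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
            and "torch" not in sys.modules[type(value).__module__].__name__
        ]

    class Example:
        def __init__(self) -> None:
            self.in_channels = 4         # kept: int, defined in builtins
            self.scale = 0.5             # kept: float
            self.weight = torch.ones(3)  # skipped: torch.Tensor
            self._cache = {}             # skipped: leading underscore

    assert non_torch_attribute_names(Example()) == ["in_channels", "scale"]
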
View file

@@ -9,8 +9,6 @@ from torch import Tensor, Size, device as Device, dtype as DType
class Downsample(Chain):
structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]
def __init__(
self,
channels: int,
@@ -59,8 +57,6 @@ class Interpolate(Module):
class Upsample(Chain):
structural_attrs = ["channels", "upsample_factor"]
def __init__(
self,
channels: int,

View file

@@ -3,8 +3,6 @@ import refiners.fluxion.layers as fl
class PositionalEncoder(fl.Chain):
structural_attrs = ["max_sequence_length", "embedding_dim"]
def __init__(
self,
max_sequence_length: int,
@@ -33,8 +31,6 @@ class PositionalEncoder(fl.Chain):
class FeedForward(fl.Chain):
structural_attrs = ["embedding_dim", "feedforward_dim"]
def __init__(
self,
embedding_dim: int,

View file

@@ -4,8 +4,6 @@ from refiners.foundationals.clip.common import PositionalEncoder, FeedForward
class ClassEncoder(fl.Chain):
structural_attrs = ["embedding_dim"]
def __init__(
self,
embedding_dim: int,
@@ -20,8 +18,6 @@ class ClassEncoder(fl.Chain):
class PatchEncoder(fl.Chain):
structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]
def __init__(
self,
in_channels: int,
@@ -50,8 +46,6 @@ class PatchEncoder(fl.Chain):
class TransformerLayer(fl.Chain):
structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]
def __init__(
self,
embedding_dim: int = 768,
@@ -80,8 +74,6 @@ class TransformerLayer(fl.Chain):
class ViTEmbeddings(fl.Chain):
structural_attrs = ["image_size", "embedding_dim", "patch_size"]
def __init__(
self,
image_size: int = 224,
@@ -90,6 +82,9 @@ class ViTEmbeddings(fl.Chain):
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.image_size = image_size
self.embedding_dim = embedding_dim
self.patch_size = patch_size
super().__init__(
fl.Concatenate(
ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
@@ -118,16 +113,6 @@ class ViTEmbeddings(fl.Chain):
class CLIPImageEncoder(fl.Chain):
structural_attrs = [
"image_size",
"embedding_dim",
"output_dim",
"patch_size",
"num_layers",
"num_attention_heads",
"feedforward_dim",
]
def __init__(
self,
image_size: int = 224,

View file

@@ -5,8 +5,6 @@ from refiners.foundationals.clip.tokenizer import CLIPTokenizer
class TokenEncoder(fl.Embedding):
structural_attrs = ["vocabulary_size", "embedding_dim"]
def __init__(
self,
vocabulary_size: int,
@@ -25,8 +23,6 @@ class TokenEncoder(fl.Embedding):
class TransformerLayer(fl.Chain):
structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
def __init__(
self,
embedding_dim: int,
@@ -74,17 +70,6 @@ class TransformerLayer(fl.Chain):
class CLIPTextEncoder(fl.Chain):
structural_attrs = [
"embedding_dim",
"max_sequence_length",
"vocabulary_size",
"num_layers",
"num_attention_heads",
"feedforward_dim",
"layer_norm_eps",
"use_quick_gelu",
]
def __init__(
self,
embedding_dim: int = 768,

View file

@@ -17,8 +17,6 @@ from PIL import Image
class Resnet(Sum):
structural_attrs = ["in_channels", "out_channels"]
def __init__(
self,
in_channels: int,
@@ -125,8 +123,6 @@ class Encoder(Chain):
class Decoder(Chain):
structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
self.latent_dim: int = 4
@@ -196,7 +192,6 @@ class Decoder(Chain):
class LatentDiffusionAutoencoder(Chain):
structural_attrs = ["encoder_scale"]
encoder_scale = 0.18125
def __init__(

View file

@@ -23,8 +23,6 @@ from refiners.fluxion.layers import (
class CrossAttentionBlock(Chain):
structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]
def __init__(
self,
embedding_dim: int,
@@ -85,8 +83,6 @@ class CrossAttentionBlock(Chain):
class StatefulFlatten(Chain):
structural_attrs = ["start_dim", "end_dim"]
def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
self.start_dim = start_dim
self.end_dim = end_dim
@@ -103,19 +99,6 @@ class StatefulFlatten(Chain):
class CrossAttentionBlock2d(Sum):
structural_attrs = [
"channels",
"in_channels",
"out_channels",
"context_embedding_dim",
"num_attention_heads",
"num_attention_layers",
"num_groups",
"context_key",
"use_linear_projection",
"projection_type",
]
def __init__(
self,
channels: int,

View file

@@ -22,8 +22,6 @@ TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
class ImageProjection(fl.Chain):
structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]
def __init__(
self,
clip_image_embedding_dim: int = 1024,
@@ -57,8 +55,6 @@ class InjectionPoint(fl.Chain):
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]
def __init__(
self,
target: fl.Attention,

View file

@@ -21,8 +21,6 @@ def compute_sinusoidal_embedding(
class RangeEncoder(fl.Chain):
structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]
def __init__(
self,
sinuosidal_embedding_dim: int,
@@ -45,8 +43,6 @@ class RangeEncoder(fl.Chain):
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
structural_attrs = ["channels", "embedding_dim", "context_key"]
def __init__(
self,
target: fl.Conv2d,

View file

@@ -19,8 +19,6 @@ class ConditionEncoder(Chain):
Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
"""
structural_attrs = ["out_channels"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.out_channels = (16, 32, 96, 256)
super().__init__(
@@ -72,8 +70,6 @@ class ConditionEncoder(Chain):
class Controlnet(Passthrough):
structural_attrs = ["scale", "name"]
def __init__(
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
) -> None:

View file

@@ -24,8 +24,6 @@ class TimestepEncoder(fl.Passthrough):
class ResidualBlock(fl.Sum):
structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]
def __init__(
self,
in_channels: int,
@@ -92,8 +90,6 @@ class CLIPLCrossAttention(CrossAttentionBlock2d):
class DownBlocks(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(
self,
in_channels: int,
@@ -211,8 +207,6 @@ class MiddleBlock(fl.Chain):
class ResidualAccumulator(fl.Passthrough):
structural_attrs = ["n"]
def __init__(self, n: int) -> None:
self.n = n
@@ -228,8 +222,6 @@ class ResidualAccumulator(fl.Passthrough):
class ResidualConcatenator(fl.Chain):
structural_attrs = ["n"]
def __init__(self, n: int) -> None:
self.n = n
@@ -243,8 +235,6 @@ class ResidualConcatenator(fl.Chain):
class SD1UNet(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(

View file

@@ -16,8 +16,6 @@ from refiners.foundationals.latent_diffusion.range_adapter import (
class TextTimeEmbedding(fl.Chain):
structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.timestep_embedding_dim = 1280
self.time_ids_embedding_dim = 256
@@ -54,8 +52,6 @@ class TextTimeEmbedding(fl.Chain):
class TimestepEncoder(fl.Passthrough):
structural_attrs = ["timestep_embedding_dim"]
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.timestep_embedding_dim = 1280
super().__init__(
@@ -98,8 +94,6 @@ class SDXLCrossAttention(CrossAttentionBlock2d):
class DownBlocks(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
@@ -158,8 +152,6 @@ class DownBlocks(fl.Chain):
class UpBlocks(fl.Chain):
structural_attrs = []
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
first_blocks = [
fl.Chain(
@@ -225,8 +217,6 @@ class UpBlocks(fl.Chain):
class MiddleBlock(fl.Chain):
structural_attrs = []
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
@@ -238,8 +228,6 @@ class MiddleBlock(fl.Chain):
class OutputBlock(fl.Chain):
structural_attrs = []
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
@@ -249,8 +237,6 @@ class OutputBlock(fl.Chain):
class SDXLUNet(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(