From 1cb798e8ae9be88c9b4ee27d73ae6c1c501c91ea Mon Sep 17 00:00:00 2001
From: Benjamin Trom
Date: Thu, 14 Sep 2023 14:34:31 +0200
Subject: [PATCH] remove structural_attrs

---
 src/refiners/fluxion/adapters/lora.py         |  4 ----
 src/refiners/fluxion/layers/attentions.py     | 14 ----------
 src/refiners/fluxion/layers/chain.py          | 24 +++----------------
 src/refiners/fluxion/layers/converter.py      |  2 --
 src/refiners/fluxion/layers/module.py         | 22 +++++++++++------
 src/refiners/fluxion/layers/sampling.py       |  4 ----
 src/refiners/foundationals/clip/common.py     |  4 ----
 .../foundationals/clip/image_encoder.py       | 21 +++-------------
 .../foundationals/clip/text_encoder.py        | 15 ------------
 .../latent_diffusion/auto_encoder.py          |  5 ----
 .../latent_diffusion/cross_attention.py       | 17 -------------
 .../latent_diffusion/image_prompt.py          |  4 ----
 .../latent_diffusion/range_adapter.py         |  4 ----
 .../stable_diffusion_1/controlnet.py          |  4 ----
 .../stable_diffusion_1/unet.py                | 10 --------
 .../stable_diffusion_xl/unet.py               | 14 ----------
 16 files changed, 21 insertions(+), 147 deletions(-)

diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py
index fe2b067..fc0d374 100644
--- a/src/refiners/fluxion/adapters/lora.py
+++ b/src/refiners/fluxion/adapters/lora.py
@@ -12,8 +12,6 @@ TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]")  # Self (see PE


 class Lora(fl.Chain):
-    structural_attrs = ["in_features", "out_features", "rank", "scale"]
-
     def __init__(
         self,
         in_features: int,
@@ -56,8 +54,6 @@ class Lora(fl.Chain):


 class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
-    structural_attrs = ["in_features", "out_features", "rank", "scale"]
-
     def __init__(
         self,
         target: fl.Linear,
diff --git a/src/refiners/fluxion/layers/attentions.py b/src/refiners/fluxion/layers/attentions.py
index 3d356b0..f7f2200 100644
--- a/src/refiners/fluxion/layers/attentions.py
+++ b/src/refiners/fluxion/layers/attentions.py
@@ -82,18 +82,6 @@ class ScaledDotProductAttention(Module):


 class Attention(Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "num_heads",
-        "heads_dim",
-        "key_embedding_dim",
-        "value_embedding_dim",
-        "inner_dim",
-        "use_bias",
-        "is_causal",
-        "is_optimized",
-    ]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -180,8 +168,6 @@ class SelfAttention(Attention):


 class SelfAttention2d(SelfAttention):
-    structural_attrs = Attention.structural_attrs + ["channels"]
-
     def __init__(
         self,
         channels: int,
diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index 374cf18..3839b9b 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -44,8 +44,6 @@ def generate_unique_names(


 class UseContext(ContextModule):
-    structural_attrs = ["context", "key", "func"]
-
     def __init__(self, context: str, key: str) -> None:
         super().__init__()
         self.context = context
@@ -74,8 +72,6 @@ class SetContext(ContextModule):
     #TODO Is there a way to create the context if it doesn't exist?
     """

-    structural_attrs = ["context", "key", "callback"]
-
     def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
         super().__init__()
         self.context = context
@@ -374,29 +370,16 @@ class Chain(ContextModule):
         Such copies can be adapted without disrupting the base model, but do not
         require extra GPU memory since the weights are in the leaves and hence not
         copied.
-
-        This assumes all subclasses define the class variable `structural_attrs` which
-        contains a list of basic attributes set in the constructor. In complicated cases
-        it may be required to overwrite that method.
         """
         if hasattr(self, "_pre_structural_copy"):
             self._pre_structural_copy()

         modules = [structural_copy(m) for m in self]
-
-        # Instantiate the right subclass, but do not initialize.
-        clone = object.__new__(self.__class__)
-
-        # Copy all basic attributes of the class declared in `structural_attrs`.
-        for k in self.__class__.structural_attrs:
-            setattr(clone, k, getattr(self, k))
-
-        # Call constructor of Chain, which among other things refreshes the context tree.
-        Chain.__init__(clone, *modules)
+        clone = super().structural_copy()
+        clone._provider = ContextProvider()

         for module in modules:
-            if isinstance(module, ContextModule):
-                module._set_parent(clone)
+            clone.append(module=module)

         if hasattr(clone, "_post_structural_copy"):
             clone._post_structural_copy(self)
@@ -477,7 +460,6 @@ class Breakpoint(ContextModule):

 class Concatenate(Chain):
     _tag = "CAT"
-    structural_attrs = ["dim"]

     def __init__(self, *modules: Module, dim: int = 0) -> None:
         super().__init__(*modules)
diff --git a/src/refiners/fluxion/layers/converter.py b/src/refiners/fluxion/layers/converter.py
index 653bb74..268bf5d 100644
--- a/src/refiners/fluxion/layers/converter.py
+++ b/src/refiners/fluxion/layers/converter.py
@@ -17,8 +17,6 @@ class Converter(ContextModule):
     Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
     """

-    structural_attrs = ["set_device", "set_dtype"]
-
     def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
         super().__init__()
         self.set_device = set_device
diff --git a/src/refiners/fluxion/layers/module.py b/src/refiners/fluxion/layers/module.py
index 5a4ea2d..3e5cf1a 100644
--- a/src/refiners/fluxion/layers/module.py
+++ b/src/refiners/fluxion/layers/module.py
@@ -1,5 +1,7 @@
 from inspect import signature, Parameter
+import sys
 from pathlib import Path
+from types import ModuleType
 from typing import Any, Generator, TypeVar, TypedDict, cast

 from torch import device as Device, dtype as DType
@@ -96,11 +98,6 @@ class ContextModule(Module):
     _parent: "list[Chain]"
     _can_refresh_parent: bool = True  # see usage in Adapter and Chain

-    # Contains simple attributes set on the instance by `__init__` in subclasses
-    # and copied by `structural_copy`. Note that is not the case of `device` since
-    # Chain's __init__ takes care of it.
-    structural_attrs: list[str] = []
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, *kwargs)
         self._parent = []
@@ -139,9 +136,20 @@ class ContextModule(Module):

     def structural_copy(self: TContextModule) -> TContextModule:
         clone = object.__new__(self.__class__)
-        for k in self.__class__.structural_attrs:
+
+        not_torch_attributes = [
+            key
+            for key, value in self.__dict__.items()
+            if not key.startswith("_")
+            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
+            and "torch" not in sys.modules[type(value).__module__].__name__
+        ]
+
+        for k in not_torch_attributes:
             setattr(clone, k, getattr(self, k))
-        ContextModule.__init__(clone)
+
+        ContextModule.__init__(self=clone)
+
         return clone


diff --git a/src/refiners/fluxion/layers/sampling.py b/src/refiners/fluxion/layers/sampling.py
index 29420c33..e5e75ff 100644
--- a/src/refiners/fluxion/layers/sampling.py
+++ b/src/refiners/fluxion/layers/sampling.py
@@ -9,8 +9,6 @@ from torch import Tensor, Size, device as Device, dtype as DType


 class Downsample(Chain):
-    structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]
-
     def __init__(
         self,
         channels: int,
@@ -59,8 +57,6 @@ class Interpolate(Module):


 class Upsample(Chain):
-    structural_attrs = ["channels", "upsample_factor"]
-
     def __init__(
         self,
         channels: int,
diff --git a/src/refiners/foundationals/clip/common.py b/src/refiners/foundationals/clip/common.py
index 399e87d..77a9a94 100644
--- a/src/refiners/foundationals/clip/common.py
+++ b/src/refiners/foundationals/clip/common.py
@@ -3,8 +3,6 @@ import refiners.fluxion.layers as fl


 class PositionalEncoder(fl.Chain):
-    structural_attrs = ["max_sequence_length", "embedding_dim"]
-
     def __init__(
         self,
         max_sequence_length: int,
@@ -33,8 +31,6 @@ class PositionalEncoder(fl.Chain):


 class FeedForward(fl.Chain):
-    structural_attrs = ["embedding_dim", "feedforward_dim"]
-
     def __init__(
         self,
         embedding_dim: int,
diff --git a/src/refiners/foundationals/clip/image_encoder.py b/src/refiners/foundationals/clip/image_encoder.py
index 590122a..910cd31 100644
--- a/src/refiners/foundationals/clip/image_encoder.py
+++ b/src/refiners/foundationals/clip/image_encoder.py
@@ -4,8 +4,6 @@ from refiners.foundationals.clip.common import PositionalEncoder, FeedForward


 class ClassEncoder(fl.Chain):
-    structural_attrs = ["embedding_dim"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -20,8 +18,6 @@ class ClassEncoder(fl.Chain):


 class PatchEncoder(fl.Chain):
-    structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]
-
     def __init__(
         self,
         in_channels: int,
@@ -50,8 +46,6 @@ class PatchEncoder(fl.Chain):


 class TransformerLayer(fl.Chain):
-    structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]
-
     def __init__(
         self,
         embedding_dim: int = 768,
@@ -80,8 +74,6 @@ class TransformerLayer(fl.Chain):


 class ViTEmbeddings(fl.Chain):
-    structural_attrs = ["image_size", "embedding_dim", "patch_size"]
-
     def __init__(
         self,
         image_size: int = 224,
@@ -90,6 +82,9 @@ class ViTEmbeddings(fl.Chain):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
+        self.image_size = image_size
+        self.embedding_dim = embedding_dim
+        self.patch_size = patch_size
         super().__init__(
             fl.Concatenate(
                 ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
@@ -118,16 +113,6 @@ class ViTEmbeddings(fl.Chain):


 class CLIPImageEncoder(fl.Chain):
-    structural_attrs = [
-        "image_size",
-        "embedding_dim",
-        "output_dim",
-        "patch_size",
-        "num_layers",
-        "num_attention_heads",
-        "feedforward_dim",
-    ]
-
     def __init__(
         self,
         image_size: int = 224,
diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py
index 638e343..5a532c6 100644
--- a/src/refiners/foundationals/clip/text_encoder.py
+++ b/src/refiners/foundationals/clip/text_encoder.py
@@ -5,8 +5,6 @@ from refiners.foundationals.clip.tokenizer import CLIPTokenizer


 class TokenEncoder(fl.Embedding):
-    structural_attrs = ["vocabulary_size", "embedding_dim"]
-
     def __init__(
         self,
         vocabulary_size: int,
@@ -25,8 +23,6 @@ class TokenEncoder(fl.Embedding):


 class TransformerLayer(fl.Chain):
-    structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -74,17 +70,6 @@ class TransformerLayer(fl.Chain):


 class CLIPTextEncoder(fl.Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "max_sequence_length",
-        "vocabulary_size",
-        "num_layers",
-        "num_attention_heads",
-        "feedforward_dim",
-        "layer_norm_eps",
-        "use_quick_gelu",
-    ]
-
     def __init__(
         self,
         embedding_dim: int = 768,
diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
index 59114c6..cdd2018 100644
--- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
@@ -17,8 +17,6 @@ from PIL import Image


 class Resnet(Sum):
-    structural_attrs = ["in_channels", "out_channels"]
-
     def __init__(
         self,
         in_channels: int,
@@ -125,8 +123,6 @@ class Encoder(Chain):


 class Decoder(Chain):
-    structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
         self.latent_dim: int = 4
@@ -196,7 +192,6 @@ class Decoder(Chain):


 class LatentDiffusionAutoencoder(Chain):
-    structural_attrs = ["encoder_scale"]
     encoder_scale = 0.18125

     def __init__(
diff --git a/src/refiners/foundationals/latent_diffusion/cross_attention.py b/src/refiners/foundationals/latent_diffusion/cross_attention.py
index c973f86..ca16bdb 100644
--- a/src/refiners/foundationals/latent_diffusion/cross_attention.py
+++ b/src/refiners/foundationals/latent_diffusion/cross_attention.py
@@ -23,8 +23,6 @@ from refiners.fluxion.layers import (


 class CrossAttentionBlock(Chain):
-    structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -85,8 +83,6 @@ class CrossAttentionBlock(Chain):


 class StatefulFlatten(Chain):
-    structural_attrs = ["start_dim", "end_dim"]
-
     def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
         self.start_dim = start_dim
         self.end_dim = end_dim
@@ -103,19 +99,6 @@ class StatefulFlatten(Chain):


 class CrossAttentionBlock2d(Sum):
-    structural_attrs = [
-        "channels",
-        "in_channels",
-        "out_channels",
-        "context_embedding_dim",
-        "num_attention_heads",
-        "num_attention_layers",
-        "num_groups",
-        "context_key",
-        "use_linear_projection",
-        "projection_type",
-    ]
-
     def __init__(
         self,
         channels: int,
diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py
index 756f71f..2250c2e 100644
--- a/src/refiners/foundationals/latent_diffusion/image_prompt.py
+++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -22,8 +22,6 @@ TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]")  # Self (see PEP 673)


 class ImageProjection(fl.Chain):
-    structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]
-
     def __init__(
         self,
         clip_image_embedding_dim: int = 1024,
@@ -57,8 +55,6 @@ class InjectionPoint(fl.Chain):


 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
-    structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]
-
     def __init__(
         self,
         target: fl.Attention,
diff --git a/src/refiners/foundationals/latent_diffusion/range_adapter.py b/src/refiners/foundationals/latent_diffusion/range_adapter.py
index a9a88a7..5b5fe97 100644
--- a/src/refiners/foundationals/latent_diffusion/range_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/range_adapter.py
@@ -21,8 +21,6 @@ def compute_sinusoidal_embedding(


 class RangeEncoder(fl.Chain):
-    structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]
-
     def __init__(
         self,
         sinuosidal_embedding_dim: int,
@@ -45,8 +43,6 @@ class RangeEncoder(fl.Chain):


 class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
-    structural_attrs = ["channels", "embedding_dim", "context_key"]
-
     def __init__(
         self,
         target: fl.Conv2d,
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
index 9a94b27..24bd58b 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
@@ -19,8 +19,6 @@ class ConditionEncoder(Chain):
     Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
     """

-    structural_attrs = ["out_channels"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.out_channels = (16, 32, 96, 256)
         super().__init__(
@@ -72,8 +70,6 @@ class ConditionEncoder(Chain):


 class Controlnet(Passthrough):
-    structural_attrs = ["scale", "name"]
-
     def __init__(
         self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
     ) -> None:
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
index 2970bf2..e2a8f89 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
@@ -24,8 +24,6 @@ class TimestepEncoder(fl.Passthrough):


 class ResidualBlock(fl.Sum):
-    structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]
-
     def __init__(
         self,
         in_channels: int,
@@ -92,8 +90,6 @@ class CLIPLCrossAttention(CrossAttentionBlock2d):


 class DownBlocks(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(
         self,
         in_channels: int,
@@ -211,8 +207,6 @@ class MiddleBlock(fl.Chain):


 class ResidualAccumulator(fl.Passthrough):
-    structural_attrs = ["n"]
-
     def __init__(self, n: int) -> None:
         self.n = n

@@ -228,8 +222,6 @@ class ResidualAccumulator(fl.Passthrough):


 class ResidualConcatenator(fl.Chain):
-    structural_attrs = ["n"]
-
     def __init__(self, n: int) -> None:
         self.n = n

@@ -243,8 +235,6 @@ class ResidualConcatenator(fl.Chain):


 class SD1UNet(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
         super().__init__(
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
index a35fd9b..4edf8c1 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
@@ -16,8 +16,6 @@ from refiners.foundationals.latent_diffusion.range_adapter import (


 class TextTimeEmbedding(fl.Chain):
-    structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.timestep_embedding_dim = 1280
         self.time_ids_embedding_dim = 256
@@ -54,8 +52,6 @@ class TextTimeEmbedding(fl.Chain):


 class TimestepEncoder(fl.Passthrough):
-    structural_attrs = ["timestep_embedding_dim"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.timestep_embedding_dim = 1280
         super().__init__(
@@ -98,8 +94,6 @@ class SDXLCrossAttention(CrossAttentionBlock2d):


 class DownBlocks(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels

@@ -158,8 +152,6 @@ class DownBlocks(fl.Chain):


 class UpBlocks(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         first_blocks = [
             fl.Chain(
@@ -225,8 +217,6 @@ class UpBlocks(fl.Chain):


 class MiddleBlock(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__(
             ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
@@ -238,8 +228,6 @@ class MiddleBlock(fl.Chain):


 class OutputBlock(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__(
             fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
@@ -249,8 +237,6 @@ class OutputBlock(fl.Chain):


 class SDXLUNet(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
         super().__init__(
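
Note for reviewers: the heart of this patch is the new predicate in `ContextModule.structural_copy`. Instead of copying the attributes a subclass declares in `structural_attrs`, the clone now receives every public instance attribute whose type is defined outside any torch module. A minimal standalone sketch of that rule (it assumes only `torch` is installed; `Block` is an illustrative stand-in, not a refiners class, and the free function below merely mirrors the patch's local variable of the same name):

    import sys
    from types import ModuleType

    from torch import nn


    class Block(nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.dim = dim                   # plain Python int -> copied
            self.scale = 0.5                 # plain Python float -> copied
            self.proj = nn.Linear(dim, dim)  # lives in `_modules` -> skipped


    def not_torch_attributes(module: nn.Module) -> list[str]:
        # Same filter as the one added to `ContextModule.structural_copy`:
        # keep public attributes whose type is not defined in a torch module.
        return [
            key
            for key, value in module.__dict__.items()
            if not key.startswith("_")
            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
            and "torch" not in sys.modules[type(value).__module__].__name__
        ]


    print(not_torch_attributes(Block(8)))  # e.g. ['training', 'dim', 'scale']

As the output suggests, `training` (a plain bool set by `nn.Module.__init__`) also passes the predicate; copying it is harmless, since the subsequent `ContextModule.__init__(self=clone)` call re-runs the module initializer and resets it anyway.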
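
From the subclass side, the benefit is that constructor attributes now survive `structural_copy` with no declaration at all, which is also why `ViTEmbeddings` above gains explicit `self.image_size = image_size` assignments: an attribute must actually be set on the instance for the automatic copy to see it. A usage sketch, assuming this patch is applied (`Doubler` is hypothetical; `fl.Chain`, `fl.Linear` and `structural_copy` are existing refiners APIs):

    import refiners.fluxion.layers as fl


    class Doubler(fl.Chain):
        def __init__(self, embedding_dim: int) -> None:
            # No `structural_attrs` declaration: any public non-torch attribute
            # set here is now picked up automatically by `structural_copy`.
            self.embedding_dim = embedding_dim
            super().__init__(
                fl.Linear(in_features=embedding_dim, out_features=2 * embedding_dim),
            )


    chain = Doubler(embedding_dim=16)
    clone = chain.structural_copy()
    assert clone is not chain
    assert clone.embedding_dim == 16  # carried over without any declaration
    assert clone[0] is chain[0]       # per the docstring, weights stay in the shared leaves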