Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)
remove structural_attrs

parent 121ef4df39
commit 1cb798e8ae
@@ -12,8 +12,6 @@ TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PE


class Lora(fl.Chain):
    structural_attrs = ["in_features", "out_features", "rank", "scale"]

    def __init__(
        self,
        in_features: int,

@@ -56,8 +54,6 @@ class Lora(fl.Chain):


class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
    structural_attrs = ["in_features", "out_features", "rank", "scale"]

    def __init__(
        self,
        target: fl.Linear,
@@ -82,18 +82,6 @@ class ScaledDotProductAttention(Module):


class Attention(Chain):
    structural_attrs = [
        "embedding_dim",
        "num_heads",
        "heads_dim",
        "key_embedding_dim",
        "value_embedding_dim",
        "inner_dim",
        "use_bias",
        "is_causal",
        "is_optimized",
    ]

    def __init__(
        self,
        embedding_dim: int,

@@ -180,8 +168,6 @@ class SelfAttention(Attention):


class SelfAttention2d(SelfAttention):
    structural_attrs = Attention.structural_attrs + ["channels"]

    def __init__(
        self,
        channels: int,
@@ -44,8 +44,6 @@ def generate_unique_names(


class UseContext(ContextModule):
    structural_attrs = ["context", "key", "func"]

    def __init__(self, context: str, key: str) -> None:
        super().__init__()
        self.context = context

@@ -74,8 +72,6 @@ class SetContext(ContextModule):
    #TODO Is there a way to create the context if it doesn't exist?
    """

    structural_attrs = ["context", "key", "callback"]

    def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
        super().__init__()
        self.context = context

@@ -374,29 +370,16 @@ class Chain(ContextModule):

        Such copies can be adapted without disrupting the base model, but do not
        require extra GPU memory since the weights are in the leaves and hence not copied.

        This assumes all subclasses define the class variable `structural_attrs` which
        contains a list of basic attributes set in the constructor. In complicated cases
        it may be required to overwrite that method.
        """
        if hasattr(self, "_pre_structural_copy"):
            self._pre_structural_copy()

        modules = [structural_copy(m) for m in self]

        # Instantiate the right subclass, but do not initialize.
        clone = object.__new__(self.__class__)

        # Copy all basic attributes of the class declared in `structural_attrs`.
        for k in self.__class__.structural_attrs:
            setattr(clone, k, getattr(self, k))

        # Call constructor of Chain, which among other things refreshes the context tree.
        Chain.__init__(clone, *modules)
        clone = super().structural_copy()
        clone._provider = ContextProvider()

        for module in modules:
            if isinstance(module, ContextModule):
                module._set_parent(clone)
            clone.append(module=module)

        if hasattr(clone, "_post_structural_copy"):
            clone._post_structural_copy(self)
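The `structural_copy` hunk above captures the idea behind this method: the chain structure is duplicated while the weight-holding leaves are reused, so a copy can be adapted without allocating extra GPU memory. A minimal sketch of that behaviour, using a hypothetical toy chain that is not part of this diff:

import refiners.fluxion.layers as fl

# Hypothetical toy chain; any Chain wrapping a weighted leaf would do.
chain = fl.Chain(fl.Linear(in_features=8, out_features=8))
clone = chain.structural_copy()

# The Chain containers are distinct objects...
assert clone is not chain

# ...but the Linear leaf still points at the same weight storage,
# so the copy costs no additional parameter memory.
linear = next(m for m in chain if isinstance(m, fl.Linear))
linear_clone = next(m for m in clone if isinstance(m, fl.Linear))
assert linear.weight.data_ptr() == linear_clone.weight.data_ptr()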
@@ -477,7 +460,6 @@ class Breakpoint(ContextModule):

class Concatenate(Chain):
    _tag = "CAT"
    structural_attrs = ["dim"]

    def __init__(self, *modules: Module, dim: int = 0) -> None:
        super().__init__(*modules)
@@ -17,8 +17,6 @@ class Converter(ContextModule):
    Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
    """

    structural_attrs = ["set_device", "set_dtype"]

    def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
        super().__init__()
        self.set_device = set_device
@@ -1,5 +1,7 @@
from inspect import signature, Parameter
import sys
from pathlib import Path
from types import ModuleType
from typing import Any, Generator, TypeVar, TypedDict, cast

from torch import device as Device, dtype as DType

@@ -96,11 +98,6 @@ class ContextModule(Module):
    _parent: "list[Chain]"
    _can_refresh_parent: bool = True  # see usage in Adapter and Chain

    # Contains simple attributes set on the instance by `__init__` in subclasses
    # and copied by `structural_copy`. Note that is not the case of `device` since
    # Chain's __init__ takes care of it.
    structural_attrs: list[str] = []

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, *kwargs)
        self._parent = []

@@ -139,9 +136,20 @@
    def structural_copy(self: TContextModule) -> TContextModule:
        clone = object.__new__(self.__class__)
        for k in self.__class__.structural_attrs:

        not_torch_attributes = [
            key
            for key, value in self.__dict__.items()
            if not key.startswith("_")
            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
            and "torch" not in sys.modules[type(value).__module__].__name__
        ]

        for k in not_torch_attributes:
            setattr(clone, k, getattr(self, k))
        ContextModule.__init__(clone)

        ContextModule.__init__(self=clone)

        return clone
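The last hunk above is the heart of the change: instead of each subclass maintaining a `structural_attrs` list, `ContextModule.structural_copy` now keeps every public instance attribute whose type is not defined in a torch module, so plain Python values (ints, strings, lists, tuples, ...) are copied while tensors, parameters and other torch objects are left alone. A standalone sketch of that filtering idiom, with a hypothetical Example class for illustration:

import sys
from types import ModuleType

import torch


class Example:
    def __init__(self) -> None:
        self.channels = 4             # plain attribute: kept
        self.name = "block"           # plain attribute: kept
        self._cache = {}              # private attribute: skipped
        self.weight = torch.zeros(4)  # torch object: skipped


obj = Example()

not_torch_attributes = [
    key
    for key, value in obj.__dict__.items()
    if not key.startswith("_")
    and isinstance(sys.modules.get(type(value).__module__), ModuleType)
    and "torch" not in sys.modules[type(value).__module__].__name__
]

assert sorted(not_torch_attributes) == ["channels", "name"]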
@@ -9,8 +9,6 @@ from torch import Tensor, Size, device as Device, dtype as DType


class Downsample(Chain):
    structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]

    def __init__(
        self,
        channels: int,

@@ -59,8 +57,6 @@ class Interpolate(Module):


class Upsample(Chain):
    structural_attrs = ["channels", "upsample_factor"]

    def __init__(
        self,
        channels: int,
@@ -3,8 +3,6 @@ import refiners.fluxion.layers as fl


class PositionalEncoder(fl.Chain):
    structural_attrs = ["max_sequence_length", "embedding_dim"]

    def __init__(
        self,
        max_sequence_length: int,

@@ -33,8 +31,6 @@ class PositionalEncoder(fl.Chain):


class FeedForward(fl.Chain):
    structural_attrs = ["embedding_dim", "feedforward_dim"]

    def __init__(
        self,
        embedding_dim: int,
@@ -4,8 +4,6 @@ from refiners.foundationals.clip.common import PositionalEncoder, FeedForward


class ClassEncoder(fl.Chain):
    structural_attrs = ["embedding_dim"]

    def __init__(
        self,
        embedding_dim: int,

@@ -20,8 +18,6 @@ class ClassEncoder(fl.Chain):


class PatchEncoder(fl.Chain):
    structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]

    def __init__(
        self,
        in_channels: int,

@@ -50,8 +46,6 @@ class PatchEncoder(fl.Chain):


class TransformerLayer(fl.Chain):
    structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]

    def __init__(
        self,
        embedding_dim: int = 768,

@@ -80,8 +74,6 @@ class TransformerLayer(fl.Chain):


class ViTEmbeddings(fl.Chain):
    structural_attrs = ["image_size", "embedding_dim", "patch_size"]

    def __init__(
        self,
        image_size: int = 224,

@@ -90,6 +82,9 @@ class ViTEmbeddings(fl.Chain):
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.image_size = image_size
        self.embedding_dim = embedding_dim
        self.patch_size = patch_size
        super().__init__(
            fl.Concatenate(
                ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),

@@ -118,16 +113,6 @@


class CLIPImageEncoder(fl.Chain):
    structural_attrs = [
        "image_size",
        "embedding_dim",
        "output_dim",
        "patch_size",
        "num_layers",
        "num_attention_heads",
        "feedforward_dim",
    ]

    def __init__(
        self,
        image_size: int = 224,
@@ -5,8 +5,6 @@ from refiners.foundationals.clip.tokenizer import CLIPTokenizer


class TokenEncoder(fl.Embedding):
    structural_attrs = ["vocabulary_size", "embedding_dim"]

    def __init__(
        self,
        vocabulary_size: int,

@@ -25,8 +23,6 @@ class TokenEncoder(fl.Embedding):


class TransformerLayer(fl.Chain):
    structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]

    def __init__(
        self,
        embedding_dim: int,

@@ -74,17 +70,6 @@ class TransformerLayer(fl.Chain):


class CLIPTextEncoder(fl.Chain):
    structural_attrs = [
        "embedding_dim",
        "max_sequence_length",
        "vocabulary_size",
        "num_layers",
        "num_attention_heads",
        "feedforward_dim",
        "layer_norm_eps",
        "use_quick_gelu",
    ]

    def __init__(
        self,
        embedding_dim: int = 768,
@@ -17,8 +17,6 @@ from PIL import Image


class Resnet(Sum):
    structural_attrs = ["in_channels", "out_channels"]

    def __init__(
        self,
        in_channels: int,

@@ -125,8 +123,6 @@ class Encoder(Chain):


class Decoder(Chain):
    structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
        self.latent_dim: int = 4

@@ -196,7 +192,6 @@ class Decoder(Chain):


class LatentDiffusionAutoencoder(Chain):
    structural_attrs = ["encoder_scale"]
    encoder_scale = 0.18125

    def __init__(
@@ -23,8 +23,6 @@ from refiners.fluxion.layers import (


class CrossAttentionBlock(Chain):
    structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]

    def __init__(
        self,
        embedding_dim: int,

@@ -85,8 +83,6 @@ class CrossAttentionBlock(Chain):


class StatefulFlatten(Chain):
    structural_attrs = ["start_dim", "end_dim"]

    def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
        self.start_dim = start_dim
        self.end_dim = end_dim

@@ -103,19 +99,6 @@ class StatefulFlatten(Chain):


class CrossAttentionBlock2d(Sum):
    structural_attrs = [
        "channels",
        "in_channels",
        "out_channels",
        "context_embedding_dim",
        "num_attention_heads",
        "num_attention_layers",
        "num_groups",
        "context_key",
        "use_linear_projection",
        "projection_type",
    ]

    def __init__(
        self,
        channels: int,
@@ -22,8 +22,6 @@ TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)


class ImageProjection(fl.Chain):
    structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]

    def __init__(
        self,
        clip_image_embedding_dim: int = 1024,

@@ -57,8 +55,6 @@ class InjectionPoint(fl.Chain):


class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]

    def __init__(
        self,
        target: fl.Attention,
@@ -21,8 +21,6 @@ def compute_sinusoidal_embedding(


class RangeEncoder(fl.Chain):
    structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]

    def __init__(
        self,
        sinuosidal_embedding_dim: int,

@@ -45,8 +43,6 @@ class RangeEncoder(fl.Chain):


class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
    structural_attrs = ["channels", "embedding_dim", "context_key"]

    def __init__(
        self,
        target: fl.Conv2d,
@@ -19,8 +19,6 @@ class ConditionEncoder(Chain):
    Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
    """

    structural_attrs = ["out_channels"]

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.out_channels = (16, 32, 96, 256)
        super().__init__(

@@ -72,8 +70,6 @@ class ConditionEncoder(Chain):


class Controlnet(Passthrough):
    structural_attrs = ["scale", "name"]

    def __init__(
        self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
@@ -24,8 +24,6 @@ class TimestepEncoder(fl.Passthrough):


class ResidualBlock(fl.Sum):
    structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]

    def __init__(
        self,
        in_channels: int,

@@ -92,8 +90,6 @@ class CLIPLCrossAttention(CrossAttentionBlock2d):


class DownBlocks(fl.Chain):
    structural_attrs = ["in_channels"]

    def __init__(
        self,
        in_channels: int,

@@ -211,8 +207,6 @@ class MiddleBlock(fl.Chain):


class ResidualAccumulator(fl.Passthrough):
    structural_attrs = ["n"]

    def __init__(self, n: int) -> None:
        self.n = n

@@ -228,8 +222,6 @@ class ResidualAccumulator(fl.Passthrough):


class ResidualConcatenator(fl.Chain):
    structural_attrs = ["n"]

    def __init__(self, n: int) -> None:
        self.n = n

@@ -243,8 +235,6 @@ class ResidualConcatenator(fl.Chain):


class SD1UNet(fl.Chain):
    structural_attrs = ["in_channels"]

    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels
        super().__init__(
@@ -16,8 +16,6 @@ from refiners.foundationals.latent_diffusion.range_adapter import (


class TextTimeEmbedding(fl.Chain):
    structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.timestep_embedding_dim = 1280
        self.time_ids_embedding_dim = 256

@@ -54,8 +52,6 @@ class TextTimeEmbedding(fl.Chain):


class TimestepEncoder(fl.Passthrough):
    structural_attrs = ["timestep_embedding_dim"]

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.timestep_embedding_dim = 1280
        super().__init__(

@@ -98,8 +94,6 @@ class SDXLCrossAttention(CrossAttentionBlock2d):


class DownBlocks(fl.Chain):
    structural_attrs = ["in_channels"]

    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels

@@ -158,8 +152,6 @@ class DownBlocks(fl.Chain):


class UpBlocks(fl.Chain):
    structural_attrs = []

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        first_blocks = [
            fl.Chain(

@@ -225,8 +217,6 @@ class UpBlocks(fl.Chain):


class MiddleBlock(fl.Chain):
    structural_attrs = []

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),

@@ -238,8 +228,6 @@ class MiddleBlock(fl.Chain):


class OutputBlock(fl.Chain):
    structural_attrs = []

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),

@@ -249,8 +237,6 @@ class OutputBlock(fl.Chain):


class SDXLUNet(fl.Chain):
    structural_attrs = ["in_channels"]

    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels
        super().__init__(