Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-25 07:38:45 +00:00)
remove structural_attrs

commit 1cb798e8ae
parent 121ef4df39
@@ -12,8 +12,6 @@ TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PE


 class Lora(fl.Chain):
-    structural_attrs = ["in_features", "out_features", "rank", "scale"]
-
     def __init__(
         self,
         in_features: int,
@@ -56,8 +54,6 @@ class Lora(fl.Chain):


 class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
-    structural_attrs = ["in_features", "out_features", "rank", "scale"]
-
     def __init__(
         self,
         target: fl.Linear,
@@ -82,18 +82,6 @@ class ScaledDotProductAttention(Module):


 class Attention(Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "num_heads",
-        "heads_dim",
-        "key_embedding_dim",
-        "value_embedding_dim",
-        "inner_dim",
-        "use_bias",
-        "is_causal",
-        "is_optimized",
-    ]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -180,8 +168,6 @@ class SelfAttention(Attention):


 class SelfAttention2d(SelfAttention):
-    structural_attrs = Attention.structural_attrs + ["channels"]
-
     def __init__(
         self,
         channels: int,
@@ -44,8 +44,6 @@ def generate_unique_names(


 class UseContext(ContextModule):
-    structural_attrs = ["context", "key", "func"]
-
     def __init__(self, context: str, key: str) -> None:
         super().__init__()
         self.context = context
@@ -74,8 +72,6 @@ class SetContext(ContextModule):
     #TODO Is there a way to create the context if it doesn't exist?
     """

-    structural_attrs = ["context", "key", "callback"]
-
     def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
         super().__init__()
         self.context = context
@@ -374,29 +370,16 @@ class Chain(ContextModule):

         Such copies can be adapted without disrupting the base model, but do not
         require extra GPU memory since the weights are in the leaves and hence not copied.
-
-        This assumes all subclasses define the class variable `structural_attrs` which
-        contains a list of basic attributes set in the constructor. In complicated cases
-        it may be required to overwrite that method.
         """
         if hasattr(self, "_pre_structural_copy"):
             self._pre_structural_copy()

         modules = [structural_copy(m) for m in self]
-
-        # Instantiate the right subclass, but do not initialize.
-        clone = object.__new__(self.__class__)
-
-        # Copy all basic attributes of the class declared in `structural_attrs`.
-        for k in self.__class__.structural_attrs:
-            setattr(clone, k, getattr(self, k))
-
-        # Call constructor of Chain, which among other things refreshes the context tree.
-        Chain.__init__(clone, *modules)
+        clone = super().structural_copy()
+        clone._provider = ContextProvider()

         for module in modules:
-            if isinstance(module, ContextModule):
-                module._set_parent(clone)
+            clone.append(module=module)

         if hasattr(clone, "_post_structural_copy"):
             clone._post_structural_copy(self)
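Note: the hunk above makes `Chain.structural_copy` delegate plain-attribute copying to `ContextModule.structural_copy` (via `super()`), reset the clone's `ContextProvider`, and re-append the structurally copied children. A minimal standalone sketch of that delegation pattern; the `Node` and `Group` classes below are hypothetical stand-ins for illustration, not refiners code:

class Node:
    def structural_copy(self) -> "Node":
        clone = object.__new__(self.__class__)
        # Copy plain public attributes; stands in for the non-torch attribute
        # filter introduced in ContextModule.structural_copy further below.
        for key, value in vars(self).items():
            if not key.startswith("_"):
                setattr(clone, key, value)
        return clone


class Group(Node):
    def __init__(self, *children: Node) -> None:
        self.name = "group"
        self._children: list[Node] = list(children)

    def append(self, module: Node) -> None:
        self._children.append(module)

    def structural_copy(self) -> "Node":
        children = [child.structural_copy() for child in self._children]
        clone = super().structural_copy()  # copies `name`, skips `_children`
        clone._children = []               # fresh child list for the clone
        for child in children:
            clone.append(module=child)
        return clone


group = Group(Node(), Node())
copy = group.structural_copy()
assert copy.name == "group" and len(copy._children) == 2
assert copy._children is not group._children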
@@ -477,7 +460,6 @@ class Breakpoint(ContextModule):

 class Concatenate(Chain):
     _tag = "CAT"
-    structural_attrs = ["dim"]

     def __init__(self, *modules: Module, dim: int = 0) -> None:
         super().__init__(*modules)
@@ -17,8 +17,6 @@ class Converter(ContextModule):
     Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
     """

-    structural_attrs = ["set_device", "set_dtype"]
-
     def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
         super().__init__()
         self.set_device = set_device
@@ -1,5 +1,7 @@
 from inspect import signature, Parameter
+import sys
 from pathlib import Path
+from types import ModuleType
 from typing import Any, Generator, TypeVar, TypedDict, cast

 from torch import device as Device, dtype as DType
@@ -96,11 +98,6 @@ class ContextModule(Module):
     _parent: "list[Chain]"
     _can_refresh_parent: bool = True # see usage in Adapter and Chain

-    # Contains simple attributes set on the instance by `__init__` in subclasses
-    # and copied by `structural_copy`. Note that is not the case of `device` since
-    # Chain's __init__ takes care of it.
-    structural_attrs: list[str] = []
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, *kwargs)
         self._parent = []
@@ -139,9 +136,20 @@ class ContextModule(Module):

     def structural_copy(self: TContextModule) -> TContextModule:
         clone = object.__new__(self.__class__)
-        for k in self.__class__.structural_attrs:
+
+        not_torch_attributes = [
+            key
+            for key, value in self.__dict__.items()
+            if not key.startswith("_")
+            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
+            and "torch" not in sys.modules[type(value).__module__].__name__
+        ]
+
+        for k in not_torch_attributes:
             setattr(clone, k, getattr(self, k))
-        ContextModule.__init__(clone)
+
+        ContextModule.__init__(self=clone)
+
         return clone

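Note: the new `structural_copy` above copies every public instance attribute whose type is not defined in a torch module, instead of reading a hand-maintained `structural_attrs` list. A minimal standalone sketch of that filter, assuming torch is installed; the `Example` class and `copyable_attributes` helper are hypothetical, for illustration only:

import sys
from types import ModuleType

import torch


class Example:
    def __init__(self) -> None:
        self.scale = 0.5                 # plain attribute: copied
        self.rank = 16                   # plain attribute: copied
        self._parent = []                # private attribute: skipped
        self.weight = torch.zeros(2, 2)  # torch object: skipped


def copyable_attributes(obj: object) -> list[str]:
    # Same filter as the hunk above: public attributes whose type does not
    # come from a torch module.
    return [
        key
        for key, value in vars(obj).items()
        if not key.startswith("_")
        and isinstance(sys.modules.get(type(value).__module__), ModuleType)
        and "torch" not in sys.modules[type(value).__module__].__name__
    ]


print(copyable_attributes(Example()))  # ['scale', 'rank']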
@@ -9,8 +9,6 @@ from torch import Tensor, Size, device as Device, dtype as DType


 class Downsample(Chain):
-    structural_attrs = ["channels", "in_channels", "out_channels", "scale_factor", "padding"]
-
     def __init__(
         self,
         channels: int,
@@ -59,8 +57,6 @@ class Interpolate(Module):


 class Upsample(Chain):
-    structural_attrs = ["channels", "upsample_factor"]
-
     def __init__(
         self,
         channels: int,
@@ -3,8 +3,6 @@ import refiners.fluxion.layers as fl


 class PositionalEncoder(fl.Chain):
-    structural_attrs = ["max_sequence_length", "embedding_dim"]
-
     def __init__(
         self,
         max_sequence_length: int,
@@ -33,8 +31,6 @@ class PositionalEncoder(fl.Chain):


 class FeedForward(fl.Chain):
-    structural_attrs = ["embedding_dim", "feedforward_dim"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -4,8 +4,6 @@ from refiners.foundationals.clip.common import PositionalEncoder, FeedForward


 class ClassEncoder(fl.Chain):
-    structural_attrs = ["embedding_dim"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -20,8 +18,6 @@ class ClassEncoder(fl.Chain):


 class PatchEncoder(fl.Chain):
-    structural_attrs = ["in_channels", "out_channels", "patch_size", "use_bias"]
-
     def __init__(
         self,
         in_channels: int,
@@ -50,8 +46,6 @@ class PatchEncoder(fl.Chain):


 class TransformerLayer(fl.Chain):
-    structural_attrs = ["embedding_dim", "feedforward_dim", "num_attention_heads", "layer_norm_eps"]
-
     def __init__(
         self,
         embedding_dim: int = 768,
@@ -80,8 +74,6 @@ class TransformerLayer(fl.Chain):


 class ViTEmbeddings(fl.Chain):
-    structural_attrs = ["image_size", "embedding_dim", "patch_size"]
-
     def __init__(
         self,
         image_size: int = 224,
@@ -90,6 +82,9 @@ class ViTEmbeddings(fl.Chain):
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
+        self.image_size = image_size
+        self.embedding_dim = embedding_dim
+        self.patch_size = patch_size
         super().__init__(
             fl.Concatenate(
                 ClassEncoder(embedding_dim=embedding_dim, device=device, dtype=dtype),
@@ -118,16 +113,6 @@ class ViTEmbeddings(fl.Chain):


 class CLIPImageEncoder(fl.Chain):
-    structural_attrs = [
-        "image_size",
-        "embedding_dim",
-        "output_dim",
-        "patch_size",
-        "num_layers",
-        "num_attention_heads",
-        "feedforward_dim",
-    ]
-
     def __init__(
         self,
         image_size: int = 224,
@@ -5,8 +5,6 @@ from refiners.foundationals.clip.tokenizer import CLIPTokenizer


 class TokenEncoder(fl.Embedding):
-    structural_attrs = ["vocabulary_size", "embedding_dim"]
-
     def __init__(
         self,
         vocabulary_size: int,
@@ -25,8 +23,6 @@ class TokenEncoder(fl.Embedding):


 class TransformerLayer(fl.Chain):
-    structural_attrs = ["embedding_dim", "num_attention_heads", "feedforward_dim", "layer_norm_eps"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -74,17 +70,6 @@ class TransformerLayer(fl.Chain):


 class CLIPTextEncoder(fl.Chain):
-    structural_attrs = [
-        "embedding_dim",
-        "max_sequence_length",
-        "vocabulary_size",
-        "num_layers",
-        "num_attention_heads",
-        "feedforward_dim",
-        "layer_norm_eps",
-        "use_quick_gelu",
-    ]
-
     def __init__(
         self,
         embedding_dim: int = 768,
@@ -17,8 +17,6 @@ from PIL import Image


 class Resnet(Sum):
-    structural_attrs = ["in_channels", "out_channels"]
-
     def __init__(
         self,
         in_channels: int,
@@ -125,8 +123,6 @@ class Encoder(Chain):


 class Decoder(Chain):
-    structural_attrs = ["resnet_sizes", "latent_dim", "output_channels"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
         self.latent_dim: int = 4
@@ -196,7 +192,6 @@ class Decoder(Chain):


 class LatentDiffusionAutoencoder(Chain):
-    structural_attrs = ["encoder_scale"]
     encoder_scale = 0.18125

     def __init__(
@@ -23,8 +23,6 @@ from refiners.fluxion.layers import (


 class CrossAttentionBlock(Chain):
-    structural_attrs = ["embedding_dim", "context_embedding_dim", "context", "context_key", "num_heads", "use_bias"]
-
     def __init__(
         self,
         embedding_dim: int,
@@ -85,8 +83,6 @@ class CrossAttentionBlock(Chain):


 class StatefulFlatten(Chain):
-    structural_attrs = ["start_dim", "end_dim"]
-
     def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
         self.start_dim = start_dim
         self.end_dim = end_dim
@@ -103,19 +99,6 @@ class StatefulFlatten(Chain):


 class CrossAttentionBlock2d(Sum):
-    structural_attrs = [
-        "channels",
-        "in_channels",
-        "out_channels",
-        "context_embedding_dim",
-        "num_attention_heads",
-        "num_attention_layers",
-        "num_groups",
-        "context_key",
-        "use_linear_projection",
-        "projection_type",
-    ]
-
     def __init__(
         self,
         channels: int,
@@ -22,8 +22,6 @@ TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)


 class ImageProjection(fl.Chain):
-    structural_attrs = ["clip_image_embedding_dim", "clip_text_embedding_dim", "sequence_length"]
-
     def __init__(
         self,
         clip_image_embedding_dim: int = 1024,
@@ -57,8 +55,6 @@ class InjectionPoint(fl.Chain):


 class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
-    structural_attrs = ["text_sequence_length", "image_sequence_length", "scale"]
-
     def __init__(
         self,
         target: fl.Attention,
@@ -21,8 +21,6 @@ def compute_sinusoidal_embedding(


 class RangeEncoder(fl.Chain):
-    structural_attrs = ["sinuosidal_embedding_dim", "embedding_dim"]
-
     def __init__(
         self,
         sinuosidal_embedding_dim: int,
@@ -45,8 +43,6 @@ class RangeEncoder(fl.Chain):


 class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
-    structural_attrs = ["channels", "embedding_dim", "context_key"]
-
     def __init__(
         self,
         target: fl.Conv2d,
@@ -19,8 +19,6 @@ class ConditionEncoder(Chain):
     Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
     """

-    structural_attrs = ["out_channels"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.out_channels = (16, 32, 96, 256)
         super().__init__(
@@ -72,8 +70,6 @@ class ConditionEncoder(Chain):


 class Controlnet(Passthrough):
-    structural_attrs = ["scale", "name"]
-
     def __init__(
         self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
     ) -> None:
@@ -24,8 +24,6 @@ class TimestepEncoder(fl.Passthrough):


 class ResidualBlock(fl.Sum):
-    structural_attrs = ["in_channels", "out_channels", "num_groups", "eps"]
-
     def __init__(
         self,
         in_channels: int,
@@ -92,8 +90,6 @@ class CLIPLCrossAttention(CrossAttentionBlock2d):


 class DownBlocks(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(
         self,
         in_channels: int,
@@ -211,8 +207,6 @@ class MiddleBlock(fl.Chain):


 class ResidualAccumulator(fl.Passthrough):
-    structural_attrs = ["n"]
-
     def __init__(self, n: int) -> None:
         self.n = n

@@ -228,8 +222,6 @@ class ResidualAccumulator(fl.Passthrough):


 class ResidualConcatenator(fl.Chain):
-    structural_attrs = ["n"]
-
     def __init__(self, n: int) -> None:
         self.n = n

@@ -243,8 +235,6 @@ class ResidualConcatenator(fl.Chain):


 class SD1UNet(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
         super().__init__(
@@ -16,8 +16,6 @@ from refiners.foundationals.latent_diffusion.range_adapter import (


 class TextTimeEmbedding(fl.Chain):
-    structural_attrs = ["timestep_embedding_dim", "time_ids_embedding_dim", "text_time_embedding_dim"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.timestep_embedding_dim = 1280
         self.time_ids_embedding_dim = 256
@@ -54,8 +52,6 @@ class TextTimeEmbedding(fl.Chain):


 class TimestepEncoder(fl.Passthrough):
-    structural_attrs = ["timestep_embedding_dim"]
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.timestep_embedding_dim = 1280
         super().__init__(
@@ -98,8 +94,6 @@ class SDXLCrossAttention(CrossAttentionBlock2d):


 class DownBlocks(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels

@@ -158,8 +152,6 @@ class DownBlocks(fl.Chain):


 class UpBlocks(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         first_blocks = [
             fl.Chain(
@@ -225,8 +217,6 @@ class UpBlocks(fl.Chain):


 class MiddleBlock(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__(
             ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
@@ -238,8 +228,6 @@ class MiddleBlock(fl.Chain):


 class OutputBlock(fl.Chain):
-    structural_attrs = []
-
     def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
         super().__init__(
             fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
@@ -249,8 +237,6 @@ class OutputBlock(fl.Chain):


 class SDXLUNet(fl.Chain):
-    structural_attrs = ["in_channels"]
-
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
         self.in_channels = in_channels
         super().__init__(