diff --git a/src/refiners/foundationals/clip/text_encoder.py b/src/refiners/foundationals/clip/text_encoder.py index 9cb2fb3..ef1e455 100644 --- a/src/refiners/foundationals/clip/text_encoder.py +++ b/src/refiners/foundationals/clip/text_encoder.py @@ -1,7 +1,6 @@ from torch import Tensor, arange, device as Device, dtype as DType import refiners.fluxion.layers as fl from refiners.foundationals.clip.tokenizer import CLIPTokenizer -import refiners.foundationals.latent_diffusion.model as ldm class TokenEncoder(fl.Embedding): @@ -122,7 +121,7 @@ class TransformerLayer(fl.Chain): ) -class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface): +class CLIPTextEncoder(fl.Chain): structural_attrs = [ "embedding_dim", "max_sequence_length", diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index c90886b..83b6789 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Protocol, TypeVar +from typing import TypeVar from torch import Tensor, device as Device, dtype as DType from PIL import Image import torch @@ -11,31 +11,15 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Schedul T = TypeVar("T", bound="fl.Module") -class UNetInterface(Protocol): - def set_timestep(self, timestep: Tensor) -> None: - ... - - def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None: - ... - - def __call__(self, x: Tensor) -> Tensor: - ... - - -class TextEncoderInterface(Protocol): - def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]: - ... - - TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel") class LatentDiffusionModel(fl.Module, ABC): def __init__( self, - unet: UNetInterface, + unet: fl.Module, lda: LatentDiffusionAutoencoder, - clip_text_encoder: TextEncoderInterface, + clip_text_encoder: fl.Module, scheduler: Scheduler, device: Device | str = "cpu", dtype: DType = torch.float32, @@ -43,10 +27,8 @@ class LatentDiffusionModel(fl.Module, ABC): super().__init__() self.device: Device = device if isinstance(device, Device) else Device(device=device) self.dtype = dtype - assert isinstance(unet, fl.Module) self.unet = unet.to(device=self.device, dtype=self.dtype) self.lda = lda.to(device=self.device, dtype=self.dtype) - assert isinstance(clip_text_encoder, fl.Module) self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype) self.scheduler = scheduler.to(device=self.device, dtype=self.dtype) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index 914f35a..314fd86 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d -import refiners.foundationals.latent_diffusion.model as ldm class TimestepEncoder(fl.Passthrough): @@ -243,7 +242,7 @@ class ResidualConcatenator(fl.Chain): ) -class SD1UNet(fl.Chain, ldm.UNetInterface): +class SD1UNet(fl.Chain): structural_attrs = ["in_channels", "clip_embedding_dim"] def __init__( diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py index 3767f2a..4123b11 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py @@ -7,7 +7,6 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE from jaxtyping import Float from refiners.foundationals.clip.tokenizer import CLIPTokenizer -from refiners.foundationals.latent_diffusion.model import TextEncoderInterface class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]): @@ -60,7 +59,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]): return x[:, end_of_text_index[0], :] -class DoubleTextEncoder(fl.Chain, TextEncoderInterface): +class DoubleTextEncoder(fl.Chain): def __init__( self, text_encoder_l: CLIPTextEncoderL | None = None, diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 8e98886..40b0819 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -3,7 +3,6 @@ from torch import Tensor, device as Device, dtype as DType from refiners.fluxion.context import Contexts import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from refiners.foundationals.latent_diffusion.model import UNetInterface from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ( ResidualAccumulator, ResidualBlock, @@ -247,7 +246,7 @@ class OutputBlock(fl.Chain): ) -class SDXLUNet(fl.Chain, UNetInterface): +class SDXLUNet(fl.Chain): structural_attrs = ["in_channels"] def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None: