remove unused TextEncoder and UNet protocols

This commit is contained in:
Benjamin Trom 2023-08-25 16:00:14 +02:00
parent a5f70b6d22
commit 8b1719b1f9
5 changed files with 7 additions and 29 deletions

View file

@ -1,7 +1,6 @@
from torch import Tensor, arange, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
import refiners.foundationals.latent_diffusion.model as ldm
class TokenEncoder(fl.Embedding):
@ -122,7 +121,7 @@ class TransformerLayer(fl.Chain):
)
class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
class CLIPTextEncoder(fl.Chain):
structural_attrs = [
"embedding_dim",
"max_sequence_length",

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Protocol, TypeVar
from typing import TypeVar
from torch import Tensor, device as Device, dtype as DType
from PIL import Image
import torch
@ -11,31 +11,15 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Schedul
T = TypeVar("T", bound="fl.Module")
class UNetInterface(Protocol):
    """Structural (duck-typed) interface for a UNet-like object.

    Any object providing the three methods below with compatible
    signatures satisfies this protocol; no inheritance is required.
    The method bodies are placeholders (``...``) and carry no behavior.
    """

    def set_timestep(self, timestep: Tensor) -> None:
        """Accept a timestep tensor; returns nothing."""
        ...

    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
        """Accept a CLIP text embedding tensor; returns nothing."""
        ...

    def __call__(self, x: Tensor) -> Tensor:
        """Map an input tensor to an output tensor."""
        # NOTE(review): the shape/meaning of ``x`` and of the result is not
        # specified by this interface — confirm against the implementations.
        ...
class TextEncoderInterface(Protocol):
def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
...
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
class LatentDiffusionModel(fl.Module, ABC):
def __init__(
self,
unet: UNetInterface,
unet: fl.Module,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: TextEncoderInterface,
clip_text_encoder: fl.Module,
scheduler: Scheduler,
device: Device | str = "cpu",
dtype: DType = torch.float32,
@ -43,10 +27,8 @@ class LatentDiffusionModel(fl.Module, ABC):
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
assert isinstance(unet, fl.Module)
self.unet = unet.to(device=self.device, dtype=self.dtype)
self.lda = lda.to(device=self.device, dtype=self.dtype)
assert isinstance(clip_text_encoder, fl.Module)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)

View file

@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
import refiners.foundationals.latent_diffusion.model as ldm
class TimestepEncoder(fl.Passthrough):
@ -243,7 +242,7 @@ class ResidualConcatenator(fl.Chain):
)
class SD1UNet(fl.Chain, ldm.UNetInterface):
class SD1UNet(fl.Chain):
structural_attrs = ["in_channels", "clip_embedding_dim"]
def __init__(

View file

@ -7,7 +7,6 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE
from jaxtyping import Float
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.model import TextEncoderInterface
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
@ -60,7 +59,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
return x[:, end_of_text_index[0], :]
class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
class DoubleTextEncoder(fl.Chain):
def __init__(
self,
text_encoder_l: CLIPTextEncoderL | None = None,

View file

@ -3,7 +3,6 @@ from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.model import UNetInterface
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
ResidualAccumulator,
ResidualBlock,
@ -247,7 +246,7 @@ class OutputBlock(fl.Chain):
)
class SDXLUNet(fl.Chain, UNetInterface):
class SDXLUNet(fl.Chain):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None: