commit 8b1719b1f9
parent a5f70b6d22
Author: Benjamin Trom
Date:   2023-08-25 16:00:14 +02:00

    remove unused TextEncoder and UNet protocols

5 changed files with 7 additions and 29 deletions
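
Background, as a hedged note: the interfaces removed below are typing.Protocol
classes, i.e. structural types that any object with matching method signatures
satisfies implicitly, with no inheritance required. Per the commit title they
were unused, so nothing type-checked against them. A minimal standalone sketch
of the pattern (names here are illustrative, not from the repo):

    from typing import Protocol

    class Callable1D(Protocol):   # structural interface: no subclassing needed
        def __call__(self, x: float) -> float: ...

    class Doubler:                # satisfies Callable1D implicitly
        def __call__(self, x: float) -> float:
            return 2 * x

    def apply(f: Callable1D, x: float) -> float:
        return f(x)

    print(apply(Doubler(), 3.0))  # 6.0 -- type-checks without any inheritance

Dropping the Protocols in favor of the concrete fl.Module annotation therefore
loses no type safety and simplifies the class bases.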

src/refiners/foundationals/clip/text_encoder.py

@@ -1,7 +1,6 @@
 from torch import Tensor, arange, device as Device, dtype as DType
 import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
-import refiners.foundationals.latent_diffusion.model as ldm
 
 
 class TokenEncoder(fl.Embedding):
@@ -122,7 +121,7 @@ class TransformerLayer(fl.Chain):
     )
 
 
-class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
+class CLIPTextEncoder(fl.Chain):
     structural_attrs = [
         "embedding_dim",
         "max_sequence_length",

src/refiners/foundationals/latent_diffusion/model.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Protocol, TypeVar
+from typing import TypeVar
 from torch import Tensor, device as Device, dtype as DType
 from PIL import Image
 import torch
@@ -11,31 +11,15 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 T = TypeVar("T", bound="fl.Module")
 
 
-class UNetInterface(Protocol):
-    def set_timestep(self, timestep: Tensor) -> None:
-        ...
-
-    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
-        ...
-
-    def __call__(self, x: Tensor) -> Tensor:
-        ...
-
-
-class TextEncoderInterface(Protocol):
-    def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
-        ...
-
-
 TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
 
 
 class LatentDiffusionModel(fl.Module, ABC):
     def __init__(
         self,
-        unet: UNetInterface,
+        unet: fl.Module,
         lda: LatentDiffusionAutoencoder,
-        clip_text_encoder: TextEncoderInterface,
+        clip_text_encoder: fl.Module,
         scheduler: Scheduler,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
@@ -43,10 +27,8 @@ class LatentDiffusionModel(fl.Module, ABC):
         super().__init__()
         self.device: Device = device if isinstance(device, Device) else Device(device=device)
         self.dtype = dtype
-        assert isinstance(unet, fl.Module)
         self.unet = unet.to(device=self.device, dtype=self.dtype)
         self.lda = lda.to(device=self.device, dtype=self.dtype)
-        assert isinstance(clip_text_encoder, fl.Module)
         self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
         self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
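
A hedged aside on the two deleted asserts: they were runtime type narrowing.
With unet annotated as a Protocol, a static checker could not see methods such
as .to(device=..., dtype=...) that come from fl.Module, so the assert was
needed before calling .to(). With the annotation now fl.Module itself, the
narrowing is redundant. A minimal sketch, assuming only that fl.Module
subclasses torch.nn.Module (which provides .to()):

    import refiners.fluxion.layers as fl

    def before(unet: "UNetInterface") -> None:
        # The Protocol exposes only its three declared methods, so the checker
        # rejects unet.to(...) unless the type is narrowed at runtime first.
        assert isinstance(unet, fl.Module)
        unet.to(device="cpu")

    def after(unet: fl.Module) -> None:
        unet.to(device="cpu")  # .to() is already visible on the annotation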

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py

@@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
 from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
-import refiners.foundationals.latent_diffusion.model as ldm
 
 
 class TimestepEncoder(fl.Passthrough):
@@ -243,7 +242,7 @@ class ResidualConcatenator(fl.Chain):
     )
 
 
-class SD1UNet(fl.Chain, ldm.UNetInterface):
+class SD1UNet(fl.Chain):
     structural_attrs = ["in_channels", "clip_embedding_dim"]
 
     def __init__(

src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py

@@ -7,7 +7,6 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
 from jaxtyping import Float
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
-from refiners.foundationals.latent_diffusion.model import TextEncoderInterface
 
 
 class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
@@ -60,7 +59,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         return x[:, end_of_text_index[0], :]
 
 
-class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
+class DoubleTextEncoder(fl.Chain):
     def __init__(
         self,
         text_encoder_l: CLIPTextEncoderL | None = None,

src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py

@@ -3,7 +3,6 @@ from torch import Tensor, device as Device, dtype as DType
 from refiners.fluxion.context import Contexts
 import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
-from refiners.foundationals.latent_diffusion.model import UNetInterface
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
     ResidualAccumulator,
     ResidualBlock,
@@ -247,7 +246,7 @@ class OutputBlock(fl.Chain):
     )
 
 
-class SDXLUNet(fl.Chain, UNetInterface):
+class SDXLUNet(fl.Chain):
     structural_attrs = ["in_channels"]
 
     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
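
Net effect, as a closing note: all four classes (CLIPTextEncoder, SD1UNet,
DoubleTextEncoder, SDXLUNet) now inherit from fl.Chain alone, and call sites
simply pass the concrete modules. A hypothetical usage sketch -- the SDXLUNet
signature is the one visible in the hunk above, the import path assumes the
SDXL package mirrors the stable_diffusion_1 layout, and in_channels=4 is only
an illustrative value:

    import refiners.fluxion.layers as fl
    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

    unet = SDXLUNet(in_channels=4)
    # Holds by inheritance via fl.Chain, so the instance satisfies the new
    # unet: fl.Module annotation on LatentDiffusionModel with no Protocol.
    assert isinstance(unet, fl.Module)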