mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
remove unused TextEncoder and UNet protocols
This commit is contained in:
parent
a5f70b6d22
commit
8b1719b1f9
|
@ -1,7 +1,6 @@
|
|||
from torch import Tensor, arange, device as Device, dtype as DType
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
import refiners.foundationals.latent_diffusion.model as ldm
|
||||
|
||||
|
||||
class TokenEncoder(fl.Embedding):
|
||||
|
@ -122,7 +121,7 @@ class TransformerLayer(fl.Chain):
|
|||
)
|
||||
|
||||
|
||||
class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
|
||||
class CLIPTextEncoder(fl.Chain):
|
||||
structural_attrs = [
|
||||
"embedding_dim",
|
||||
"max_sequence_length",
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Protocol, TypeVar
|
||||
from typing import TypeVar
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
@ -11,31 +11,15 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Schedul
|
|||
T = TypeVar("T", bound="fl.Module")
|
||||
|
||||
|
||||
class UNetInterface(Protocol):
|
||||
def set_timestep(self, timestep: Tensor) -> None:
|
||||
...
|
||||
|
||||
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||
...
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
...
|
||||
|
||||
|
||||
class TextEncoderInterface(Protocol):
|
||||
def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
|
||||
...
|
||||
|
||||
|
||||
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
||||
|
||||
|
||||
class LatentDiffusionModel(fl.Module, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNetInterface,
|
||||
unet: fl.Module,
|
||||
lda: LatentDiffusionAutoencoder,
|
||||
clip_text_encoder: TextEncoderInterface,
|
||||
clip_text_encoder: fl.Module,
|
||||
scheduler: Scheduler,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
|
@ -43,10 +27,8 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
super().__init__()
|
||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||
self.dtype = dtype
|
||||
assert isinstance(unet, fl.Module)
|
||||
self.unet = unet.to(device=self.device, dtype=self.dtype)
|
||||
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
||||
assert isinstance(clip_text_encoder, fl.Module)
|
||||
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
|
|||
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
|
||||
import refiners.foundationals.latent_diffusion.model as ldm
|
||||
|
||||
|
||||
class TimestepEncoder(fl.Passthrough):
|
||||
|
@ -243,7 +242,7 @@ class ResidualConcatenator(fl.Chain):
|
|||
)
|
||||
|
||||
|
||||
class SD1UNet(fl.Chain, ldm.UNetInterface):
|
||||
class SD1UNet(fl.Chain):
|
||||
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -7,7 +7,6 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE
|
|||
from jaxtyping import Float
|
||||
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.foundationals.latent_diffusion.model import TextEncoderInterface
|
||||
|
||||
|
||||
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||
|
@ -60,7 +59,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
|||
return x[:, end_of_text_index[0], :]
|
||||
|
||||
|
||||
class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
|
||||
class DoubleTextEncoder(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder_l: CLIPTextEncoderL | None = None,
|
||||
|
|
|
@ -3,7 +3,6 @@ from torch import Tensor, device as Device, dtype as DType
|
|||
from refiners.fluxion.context import Contexts
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.foundationals.latent_diffusion.model import UNetInterface
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||
ResidualAccumulator,
|
||||
ResidualBlock,
|
||||
|
@ -247,7 +246,7 @@ class OutputBlock(fl.Chain):
|
|||
)
|
||||
|
||||
|
||||
class SDXLUNet(fl.Chain, UNetInterface):
|
||||
class SDXLUNet(fl.Chain):
|
||||
structural_attrs = ["in_channels"]
|
||||
|
||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
|
|
Loading…
Reference in a new issue