mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
remove unused TextEncoder and UNet protocols
This commit is contained in:
parent
a5f70b6d22
commit
8b1719b1f9
|
@ -1,7 +1,6 @@
|
||||||
from torch import Tensor, arange, device as Device, dtype as DType
|
from torch import Tensor, arange, device as Device, dtype as DType
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
import refiners.foundationals.latent_diffusion.model as ldm
|
|
||||||
|
|
||||||
|
|
||||||
class TokenEncoder(fl.Embedding):
|
class TokenEncoder(fl.Embedding):
|
||||||
|
@ -122,7 +121,7 @@ class TransformerLayer(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
|
class CLIPTextEncoder(fl.Chain):
|
||||||
structural_attrs = [
|
structural_attrs = [
|
||||||
"embedding_dim",
|
"embedding_dim",
|
||||||
"max_sequence_length",
|
"max_sequence_length",
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Protocol, TypeVar
|
from typing import TypeVar
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
|
@ -11,31 +11,15 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Schedul
|
||||||
T = TypeVar("T", bound="fl.Module")
|
T = TypeVar("T", bound="fl.Module")
|
||||||
|
|
||||||
|
|
||||||
class UNetInterface(Protocol):
|
|
||||||
def set_timestep(self, timestep: Tensor) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
def __call__(self, x: Tensor) -> Tensor:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class TextEncoderInterface(Protocol):
|
|
||||||
def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
||||||
|
|
||||||
|
|
||||||
class LatentDiffusionModel(fl.Module, ABC):
|
class LatentDiffusionModel(fl.Module, ABC):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
unet: UNetInterface,
|
unet: fl.Module,
|
||||||
lda: LatentDiffusionAutoencoder,
|
lda: LatentDiffusionAutoencoder,
|
||||||
clip_text_encoder: TextEncoderInterface,
|
clip_text_encoder: fl.Module,
|
||||||
scheduler: Scheduler,
|
scheduler: Scheduler,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
|
@ -43,10 +27,8 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
assert isinstance(unet, fl.Module)
|
|
||||||
self.unet = unet.to(device=self.device, dtype=self.dtype)
|
self.unet = unet.to(device=self.device, dtype=self.dtype)
|
||||||
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
||||||
assert isinstance(clip_text_encoder, fl.Module)
|
|
||||||
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||||
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
|
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
|
||||||
import refiners.foundationals.latent_diffusion.model as ldm
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(fl.Passthrough):
|
class TimestepEncoder(fl.Passthrough):
|
||||||
|
@ -243,7 +242,7 @@ class ResidualConcatenator(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SD1UNet(fl.Chain, ldm.UNetInterface):
|
class SD1UNet(fl.Chain):
|
||||||
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -7,7 +7,6 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
|
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
from refiners.foundationals.latent_diffusion.model import TextEncoderInterface
|
|
||||||
|
|
||||||
|
|
||||||
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||||
|
@ -60,7 +59,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||||
return x[:, end_of_text_index[0], :]
|
return x[:, end_of_text_index[0], :]
|
||||||
|
|
||||||
|
|
||||||
class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
|
class DoubleTextEncoder(fl.Chain):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
text_encoder_l: CLIPTextEncoderL | None = None,
|
text_encoder_l: CLIPTextEncoderL | None = None,
|
||||||
|
|
|
@ -3,7 +3,6 @@ from torch import Tensor, device as Device, dtype as DType
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.foundationals.latent_diffusion.model import UNetInterface
|
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
ResidualAccumulator,
|
ResidualAccumulator,
|
||||||
ResidualBlock,
|
ResidualBlock,
|
||||||
|
@ -247,7 +246,7 @@ class OutputBlock(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLUNet(fl.Chain, UNetInterface):
|
class SDXLUNet(fl.Chain):
|
||||||
structural_attrs = ["in_channels"]
|
structural_attrs = ["in_channels"]
|
||||||
|
|
||||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
|
Loading…
Reference in a new issue