Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-09 23:12:02 +00:00
improve typing of ldm and sd1, introducing SD1Autoencoder class
This commit is contained in:
parent 78e69c7da0
commit 02af8e9f0b
@@ -197,13 +197,13 @@ class Decoder(Chain):

 class LatentDiffusionAutoencoder(Chain):
     structural_attrs = ["encoder_scale"]
+    encoder_scale = 0.18125

     def __init__(
         self,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
-        self.encoder_scale: float = 0.18215
         super().__init__(
             Encoder(device=device, dtype=dtype),
             Decoder(device=device, dtype=dtype),
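Note: the substantive change in this hunk is that encoder_scale moves from an instance attribute assigned in __init__ to a class attribute, so a subclass can override it declaratively and type checkers can see it without evaluating the constructor. A minimal standalone sketch of the pattern (illustrative names, not the repository's code):

    class Autoencoder:
        # class-level default, visible to type checkers without running __init__
        encoder_scale: float = 0.18125

    class SD1Variant(Autoencoder):
        # a subclass overrides by declaration alone
        encoder_scale: float = 0.18215

    assert Autoencoder().encoder_scale == 0.18125
    assert SD1Variant().encoder_scale == 0.18215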
@@ -66,19 +66,14 @@ class LatentDiffusionModel(fl.Module, ABC):
         return self.scheduler.steps

     @abstractmethod
-    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None:
+    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
         ...

     def forward(
-        self,
-        x: Tensor,
-        step: int,
-        clip_text_embedding: Tensor,
-        *args: Tensor,
-        condition_scale: float = 7.5,
+        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
     ) -> Tensor:
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
-        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, *args)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

         latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
         unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
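Note: the bare * makes timestep and clip_text_embedding keyword-only, and **kwargs replaces the positional *args, so extra conditioning tensors now travel by name. The old call mixed positional unpacking into a keyword call, which is easy to get wrong; the new one forwards **kwargs untouched. A standalone sketch of the convention (hypothetical tensor names, not the repository's API):

    from torch import Tensor, zeros

    def set_unet_context(*, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
        # implementations pick the tensors they need; **_ silently absorbs the rest
        print(timestep.shape, clip_text_embedding.shape)

    # a hypothetical extra condition is passed by name and simply ignored here:
    set_unet_context(timestep=zeros(1), clip_text_embedding=zeros(2, 77, 768), pose_embedding=zeros(1))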
@@ -11,6 +11,10 @@ import numpy as np
 from torch import device as Device, dtype as DType, Tensor


+class SD1Autoencoder(LatentDiffusionAutoencoder):
+    encoder_scale: float = 0.18215
+
+
 class StableDiffusion_1(LatentDiffusionModel):
     unet: SD1UNet
     clip_text_encoder: CLIPTextEncoderL
@@ -18,14 +22,14 @@ class StableDiffusion_1(LatentDiffusionModel):
     def __init__(
         self,
         unet: SD1UNet | None = None,
-        lda: LatentDiffusionAutoencoder | None = None,
+        lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
         scheduler: Scheduler | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SD1UNet(in_channels=4)
-        lda = lda or LatentDiffusionAutoencoder()
+        lda = lda or SD1Autoencoder()
         clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
         scheduler = scheduler or DPMSolver(num_inference_steps=30)
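Note: with this default, an argument-less StableDiffusion_1 now builds an SD1Autoencoder, so the SD1-specific scale applies out of the box. A hedged usage sketch (assuming the model exposes its autoencoder as lda, as the trainer code further down suggests):

    from refiners.foundationals.latent_diffusion import StableDiffusion_1
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder

    sd15 = StableDiffusion_1()  # lda defaults to SD1Autoencoder() per the hunk above
    assert isinstance(sd15.lda, SD1Autoencoder)
    assert sd15.lda.encoder_scale == 0.18215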
@@ -46,7 +50,7 @@ class StableDiffusion_1(LatentDiffusionModel):
         negative_embedding = self.clip_text_encoder(negative_text or "")
         return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)

-    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
+    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
         self.unet.set_timestep(timestep=timestep)
         self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
@@ -55,7 +59,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
     def __init__(
         self,
         unet: SD1UNet | None = None,
-        lda: LatentDiffusionAutoencoder | None = None,
+        lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
         scheduler: Scheduler | None = None,
         device: Device | str = "cpu",
@@ -68,12 +72,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
         )

     def forward(
-        self,
-        x: Tensor,
-        step: int,
-        clip_text_embedding: Tensor,
-        *_: Tensor,
-        condition_scale: float = 7.5,
+        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
     ) -> Tensor:
         assert self.mask_latents is not None
         assert self.target_image_latents is not None
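Note: both forward overrides now share one keyword-only signature, and the inpainting variant uses **_ to accept and discard extras that the base class forwards; callers must name the embedding explicitly. A rough sketch of a denoising loop under that convention (tensor shapes are purely illustrative and no weights are loaded):

    import torch
    from refiners.foundationals.latent_diffusion import StableDiffusion_1

    sd15 = StableDiffusion_1()
    # stand-in for an actual text embedding; batch of 2 for classifier-free guidance
    clip_text_embedding = torch.randn(2, 77, 768)
    x = torch.randn(1, 4, 64, 64)
    for step in sd15.steps:  # steps comes from the scheduler (see the model hunk above)
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)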
@@ -9,12 +9,12 @@ from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  #
 import refiners.fluxion.layers as fl
 from PIL import Image
 from functools import cached_property
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
 from refiners.training_utils.config import BaseConfig
 from refiners.foundationals.latent_diffusion import (
     StableDiffusion_1,
     DPMSolver,
     SD1UNet,
-    LatentDiffusionAutoencoder,
 )
 from refiners.foundationals.latent_diffusion.schedulers import DDPM
 from torch.nn.functional import mse_loss
@@ -136,9 +136,9 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         return CLIPTextEncoderL(device=self.device).to(device=self.device)

     @cached_property
-    def lda(self) -> LatentDiffusionAutoencoder:
+    def lda(self) -> SD1Autoencoder:
         assert self.config.models["lda"] is not None, "The config must contain a lda entry."
-        return LatentDiffusionAutoencoder(device=self.device).to(device=self.device)
+        return SD1Autoencoder(device=self.device).to(device=self.device)

     def load_models(self) -> dict[str, fl.Module]:
         return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
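Note on the constant itself: in Stable Diffusion 1.x, VAE latents are conventionally multiplied by 0.18215 after encoding and divided by the same factor before decoding, which is why the value lives on a class shared by both halves. A standalone sketch of that convention (not the repository's code):

    import torch

    SD1_SCALE = 0.18215  # canonical SD 1.x latent scale factor

    def to_latents(encode, image: torch.Tensor) -> torch.Tensor:
        # bring raw VAE outputs to roughly unit variance
        return encode(image) * SD1_SCALE

    def to_image(decode, latents: torch.Tensor) -> torch.Tensor:
        # undo the scaling before decoding
        return decode(latents / SD1_SCALE)

    # a round trip with identity codecs shows the pairing must match:
    x = torch.randn(1, 4, 64, 64)
    assert torch.allclose(to_image(lambda t: t, to_latents(lambda t: t, x)), x)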