improve typing of ldm and sd1, introducing SD1Autoencoder class

Author: limiteinductive (2023-09-06 18:37:09 +02:00), committed by Benjamin Trom
Parent: 78e69c7da0
Commit: 02af8e9f0b
4 changed files with 16 additions and 22 deletions
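The thrust of the change: `encoder_scale` moves from an instance attribute to a class attribute (so the new `SD1Autoencoder` can override it declaratively), and the conditioning arguments of `forward` and `set_unet_context` become keyword-only, with extras carried by `**kwargs` instead of `*args`. A minimal before/after sketch of a call site; the names `sd`, `x` and `emb` are illustrative, not from this diff:

    # before: clip_text_embedding could bind positionally, extras rode in *args
    prediction = sd(x, step, emb, condition_scale=7.5)
    # after: everything past x and step must be passed by name
    prediction = sd(x, step, clip_text_embedding=emb, condition_scale=7.5)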

@@ -197,13 +197,13 @@ class Decoder(Chain):
 class LatentDiffusionAutoencoder(Chain):
     structural_attrs = ["encoder_scale"]
+    encoder_scale = 0.18125

     def __init__(
         self,
         device: Device | str | None = None,
         dtype: DType | None = None,
     ) -> None:
-        self.encoder_scale: float = 0.18215
         super().__init__(
             Encoder(device=device, dtype=dtype),
             Decoder(device=device, dtype=dtype),
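Moving the scale from `__init__` to the class body is what enables the subclassing further down: the value is now part of the type, so `SD1Autoencoder` can override it by redeclaration without touching `__init__`. A standalone sketch of the pattern, not refiners code:

    class Base:
        scale: float = 0.18125  # class-level default, visible to type checkers

    class Child(Base):
        scale: float = 0.18215  # subclass overrides by redeclaring

    assert Child().scale == 0.18215  # instances resolve to the subclass value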

@@ -66,19 +66,14 @@ class LatentDiffusionModel(fl.Module, ABC):
         return self.scheduler.steps

     @abstractmethod
-    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None:
+    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
         ...

     def forward(
-        self,
-        x: Tensor,
-        step: int,
-        clip_text_embedding: Tensor,
-        *args: Tensor,
-        condition_scale: float = 7.5,
+        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
     ) -> Tensor:
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
-        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, *args)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
         latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
         unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
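A standalone sketch of what the reworked signature buys; the names are illustrative, not refiners code. Keyword-only parameters turn positional misbinding into an immediate `TypeError`, and `**kwargs` forwards any extra conditioning tensors to `set_unet_context` by name:

    def set_unet_context(*, timestep, clip_text_embedding, **extras):
        print(sorted(extras))  # a concrete model picks out the extras it expects

    def forward(x, step, *, clip_text_embedding, condition_scale=7.5, **kwargs):
        # step stands in for the scheduler timestep lookup
        set_unet_context(timestep=step, clip_text_embedding=clip_text_embedding, **kwargs)

    forward("x", 0, clip_text_embedding="emb", mask="m")  # prints ['mask']
    # forward("x", 0, "emb") raises: takes 2 positional arguments but 3 were given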

@@ -11,6 +11,10 @@ import numpy as np
 from torch import device as Device, dtype as DType, Tensor


+class SD1Autoencoder(LatentDiffusionAutoencoder):
+    encoder_scale: float = 0.18215
+
+
 class StableDiffusion_1(LatentDiffusionModel):
     unet: SD1UNet
     clip_text_encoder: CLIPTextEncoderL
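Those few added lines are the entire new class: the SD1-specific scale lives in the type rather than in constructor logic. A quick hypothetical check:

    lda = SD1Autoencoder()  # device/dtype/weights elided
    assert lda.encoder_scale == 0.18215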
@@ -18,14 +22,14 @@ class StableDiffusion_1(LatentDiffusionModel):
     def __init__(
         self,
         unet: SD1UNet | None = None,
-        lda: LatentDiffusionAutoencoder | None = None,
+        lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
         scheduler: Scheduler | None = None,
         device: Device | str = "cpu",
         dtype: DType = torch.float32,
     ) -> None:
         unet = unet or SD1UNet(in_channels=4)
-        lda = lda or LatentDiffusionAutoencoder()
+        lda = lda or SD1Autoencoder()
         clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
         scheduler = scheduler or DPMSolver(num_inference_steps=30)
@@ -46,7 +50,7 @@ class StableDiffusion_1(LatentDiffusionModel):
         negative_embedding = self.clip_text_encoder(negative_text or "")
         return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)

-    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
+    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
         self.unet.set_timestep(timestep=timestep)
         self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
@@ -55,7 +59,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
     def __init__(
         self,
         unet: SD1UNet | None = None,
-        lda: LatentDiffusionAutoencoder | None = None,
+        lda: SD1Autoencoder | None = None,
         clip_text_encoder: CLIPTextEncoderL | None = None,
         scheduler: Scheduler | None = None,
         device: Device | str = "cpu",
@@ -68,12 +72,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
         )

     def forward(
-        self,
-        x: Tensor,
-        step: int,
-        clip_text_embedding: Tensor,
-        *_: Tensor,
-        condition_scale: float = 7.5,
+        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
     ) -> Tensor:
         assert self.mask_latents is not None
         assert self.target_image_latents is not None
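Note the deliberate asymmetry with the base class: there, extras arrive as `**kwargs` and are forwarded to `set_unet_context`, while the inpainting `forward` accepts them as `**_` and ignores them, since it conditions on its stored `mask_latents` and `target_image_latents` instead.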

@@ -9,12 +9,12 @@ from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip #
 import refiners.fluxion.layers as fl
 from PIL import Image
 from functools import cached_property
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
 from refiners.training_utils.config import BaseConfig
 from refiners.foundationals.latent_diffusion import (
     StableDiffusion_1,
     DPMSolver,
     SD1UNet,
-    LatentDiffusionAutoencoder,
 )
 from refiners.foundationals.latent_diffusion.schedulers import DDPM
 from torch.nn.functional import mse_loss
@@ -136,9 +136,9 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
         return CLIPTextEncoderL(device=self.device).to(device=self.device)

     @cached_property
-    def lda(self) -> LatentDiffusionAutoencoder:
+    def lda(self) -> SD1Autoencoder:
         assert self.config.models["lda"] is not None, "The config must contain a lda entry."
-        return LatentDiffusionAutoencoder(device=self.device).to(device=self.device)
+        return SD1Autoencoder(device=self.device).to(device=self.device)

     def load_models(self) -> dict[str, fl.Module]:
         return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}