From 02af8e9f0b7e6a57764f9a76e9800650cc0a6e84 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Wed, 6 Sep 2023 18:37:09 +0200 Subject: [PATCH] improve typing of ldm and sd1, introducing SD1Autoencoder class --- .../latent_diffusion/auto_encoder.py | 2 +- .../foundationals/latent_diffusion/model.py | 11 +++-------- .../stable_diffusion_1/model.py | 19 +++++++++---------- .../training_utils/latent_diffusion.py | 6 +++--- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py index 96fad5a..59114c6 100644 --- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py @@ -197,13 +197,13 @@ class Decoder(Chain): class LatentDiffusionAutoencoder(Chain): structural_attrs = ["encoder_scale"] + encoder_scale = 0.18215 def __init__( self, device: Device | str | None = None, dtype: DType | None = None, ) -> None: - self.encoder_scale: float = 0.18215 super().__init__( Encoder(device=device, dtype=dtype), Decoder(device=device, dtype=dtype), diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 83b6789..fcdd298 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -66,19 +66,14 @@ class LatentDiffusionModel(fl.Module, ABC): return self.scheduler.steps @abstractmethod - def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None: + def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ... 
def forward( - self, - x: Tensor, - step: int, - clip_text_embedding: Tensor, - *args: Tensor, - condition_scale: float = 7.5, + self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor ) -> Tensor: timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) - self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, *args) + self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs) latents = torch.cat(tensors=(x, x)) # for classifier-free guidance unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 21a1487..3bf5945 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -11,6 +11,10 @@ import numpy as np from torch import device as Device, dtype as DType, Tensor +class SD1Autoencoder(LatentDiffusionAutoencoder): + encoder_scale: float = 0.18215 + + class StableDiffusion_1(LatentDiffusionModel): unet: SD1UNet clip_text_encoder: CLIPTextEncoderL @@ -18,14 +22,14 @@ class StableDiffusion_1(LatentDiffusionModel): def __init__( self, unet: SD1UNet | None = None, - lda: LatentDiffusionAutoencoder | None = None, + lda: SD1Autoencoder | None = None, clip_text_encoder: CLIPTextEncoderL | None = None, scheduler: Scheduler | None = None, device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: unet = unet or SD1UNet(in_channels=4) - lda = lda or LatentDiffusionAutoencoder() + lda = lda or SD1Autoencoder() clip_text_encoder = clip_text_encoder or CLIPTextEncoderL() scheduler = scheduler or DPMSolver(num_inference_steps=30) @@ -46,7 +50,7 @@ class StableDiffusion_1(LatentDiffusionModel): negative_embedding = self.clip_text_encoder(negative_text or "") return 
torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0) - def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None: + def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: self.unet.set_timestep(timestep=timestep) self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) @@ -55,7 +59,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): def __init__( self, unet: SD1UNet | None = None, - lda: LatentDiffusionAutoencoder | None = None, + lda: SD1Autoencoder | None = None, clip_text_encoder: CLIPTextEncoderL | None = None, scheduler: Scheduler | None = None, device: Device | str = "cpu", @@ -68,12 +72,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): ) def forward( - self, - x: Tensor, - step: int, - clip_text_embedding: Tensor, - *_: Tensor, - condition_scale: float = 7.5, + self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor ) -> Tensor: assert self.mask_latents is not None assert self.target_image_latents is not None diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 4ce4d4a..9c95701 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -9,12 +9,12 @@ from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # import refiners.fluxion.layers as fl from PIL import Image from functools import cached_property +from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.training_utils.config import BaseConfig from refiners.foundationals.latent_diffusion import ( StableDiffusion_1, DPMSolver, SD1UNet, - LatentDiffusionAutoencoder, ) from refiners.foundationals.latent_diffusion.schedulers import DDPM from torch.nn.functional import mse_loss @@ -136,9 +136,9 @@ class 
LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): return CLIPTextEncoderL(device=self.device).to(device=self.device) @cached_property - def lda(self) -> LatentDiffusionAutoencoder: + def lda(self) -> SD1Autoencoder: assert self.config.models["lda"] is not None, "The config must contain a lda entry." - return LatentDiffusionAutoencoder(device=self.device).to(device=self.device) + return SD1Autoencoder(device=self.device).to(device=self.device) def load_models(self) -> dict[str, fl.Module]: return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}