mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
improve typing of ldm and sd1, introducing SD1Autoencoder class
This commit is contained in:
parent
78e69c7da0
commit
02af8e9f0b
|
@ -197,13 +197,13 @@ class Decoder(Chain):
|
||||||
|
|
||||||
class LatentDiffusionAutoencoder(Chain):
|
class LatentDiffusionAutoencoder(Chain):
|
||||||
structural_attrs = ["encoder_scale"]
|
structural_attrs = ["encoder_scale"]
|
||||||
|
encoder_scale = 0.18125
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.encoder_scale: float = 0.18215
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Encoder(device=device, dtype=dtype),
|
Encoder(device=device, dtype=dtype),
|
||||||
Decoder(device=device, dtype=dtype),
|
Decoder(device=device, dtype=dtype),
|
||||||
|
|
|
@ -66,19 +66,14 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
return self.scheduler.steps
|
return self.scheduler.steps
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None:
|
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||||
x: Tensor,
|
|
||||||
step: int,
|
|
||||||
clip_text_embedding: Tensor,
|
|
||||||
*args: Tensor,
|
|
||||||
condition_scale: float = 7.5,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||||
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, *args)
|
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
|
||||||
|
|
||||||
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
||||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||||
|
|
|
@ -11,6 +11,10 @@ import numpy as np
|
||||||
from torch import device as Device, dtype as DType, Tensor
|
from torch import device as Device, dtype as DType, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class SD1Autoencoder(LatentDiffusionAutoencoder):
|
||||||
|
encoder_scale: float = 0.18215
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion_1(LatentDiffusionModel):
|
class StableDiffusion_1(LatentDiffusionModel):
|
||||||
unet: SD1UNet
|
unet: SD1UNet
|
||||||
clip_text_encoder: CLIPTextEncoderL
|
clip_text_encoder: CLIPTextEncoderL
|
||||||
|
@ -18,14 +22,14 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
unet: SD1UNet | None = None,
|
unet: SD1UNet | None = None,
|
||||||
lda: LatentDiffusionAutoencoder | None = None,
|
lda: SD1Autoencoder | None = None,
|
||||||
clip_text_encoder: CLIPTextEncoderL | None = None,
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
scheduler: Scheduler | None = None,
|
scheduler: Scheduler | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
unet = unet or SD1UNet(in_channels=4)
|
unet = unet or SD1UNet(in_channels=4)
|
||||||
lda = lda or LatentDiffusionAutoencoder()
|
lda = lda or SD1Autoencoder()
|
||||||
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
||||||
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
||||||
|
|
||||||
|
@ -46,7 +50,7 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
negative_embedding = self.clip_text_encoder(negative_text or "")
|
negative_embedding = self.clip_text_encoder(negative_text or "")
|
||||||
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
|
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
|
||||||
|
|
||||||
def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
|
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
||||||
self.unet.set_timestep(timestep=timestep)
|
self.unet.set_timestep(timestep=timestep)
|
||||||
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||||
|
|
||||||
|
@ -55,7 +59,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
unet: SD1UNet | None = None,
|
unet: SD1UNet | None = None,
|
||||||
lda: LatentDiffusionAutoencoder | None = None,
|
lda: SD1Autoencoder | None = None,
|
||||||
clip_text_encoder: CLIPTextEncoderL | None = None,
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
scheduler: Scheduler | None = None,
|
scheduler: Scheduler | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
|
@ -68,12 +72,7 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
|
||||||
x: Tensor,
|
|
||||||
step: int,
|
|
||||||
clip_text_embedding: Tensor,
|
|
||||||
*_: Tensor,
|
|
||||||
condition_scale: float = 7.5,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert self.mask_latents is not None
|
assert self.mask_latents is not None
|
||||||
assert self.target_image_latents is not None
|
assert self.target_image_latents is not None
|
||||||
|
|
|
@ -9,12 +9,12 @@ from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip #
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.config import BaseConfig
|
||||||
from refiners.foundationals.latent_diffusion import (
|
from refiners.foundationals.latent_diffusion import (
|
||||||
StableDiffusion_1,
|
StableDiffusion_1,
|
||||||
DPMSolver,
|
DPMSolver,
|
||||||
SD1UNet,
|
SD1UNet,
|
||||||
LatentDiffusionAutoencoder,
|
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import DDPM
|
from refiners.foundationals.latent_diffusion.schedulers import DDPM
|
||||||
from torch.nn.functional import mse_loss
|
from torch.nn.functional import mse_loss
|
||||||
|
@ -136,9 +136,9 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
return CLIPTextEncoderL(device=self.device).to(device=self.device)
|
return CLIPTextEncoderL(device=self.device).to(device=self.device)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def lda(self) -> LatentDiffusionAutoencoder:
|
def lda(self) -> SD1Autoencoder:
|
||||||
assert self.config.models["lda"] is not None, "The config must contain a lda entry."
|
assert self.config.models["lda"] is not None, "The config must contain a lda entry."
|
||||||
return LatentDiffusionAutoencoder(device=self.device).to(device=self.device)
|
return SD1Autoencoder(device=self.device).to(device=self.device)
|
||||||
|
|
||||||
def load_models(self) -> dict[str, fl.Module]:
|
def load_models(self) -> dict[str, fl.Module]:
|
||||||
return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
|
return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
|
||||||
|
|
Loading…
Reference in a new issue