diff --git a/scripts/training/finetune-ldm.py b/scripts/training/finetune-ldm.py
deleted file mode 100644
index c1e7808..0000000
--- a/scripts/training/finetune-ldm.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from refiners.training_utils.latent_diffusion import FinetuneLatentDiffusionConfig, LatentDiffusionTrainer
-
-if __name__ == "__main__":
-    import sys
-
-    config_path = sys.argv[1]
-    config = FinetuneLatentDiffusionConfig.load_from_toml(
-        toml_path=config_path,
-    )
-    trainer = LatentDiffusionTrainer(config=config)
-    trainer.train()
diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py
index f4c9950..617a8b4 100644
--- a/src/refiners/training_utils/__init__.py
+++ b/src/refiners/training_utils/__init__.py
@@ -4,6 +4,9 @@ from importlib.metadata import requires
 
 from packaging.requirements import Requirement
 
+from refiners.training_utils.config import BaseConfig
+from refiners.training_utils.trainer import Trainer
+
 refiners_requires = requires("refiners")
 assert refiners_requires is not None
 
@@ -21,3 +24,9 @@ for dep in refiners_requires:
             file=sys.stderr,
         )
         sys.exit(1)
+
+
+__all__ = [
+    "Trainer",
+    "BaseConfig",
+]
diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
deleted file mode 100644
index 9600ce6..0000000
--- a/src/refiners/training_utils/latent_diffusion.py
+++ /dev/null
@@ -1,260 +0,0 @@
-import random
-from dataclasses import dataclass
-from functools import cached_property
-from typing import Any, Callable, TypedDict, TypeVar
-
-from loguru import logger
-from PIL import Image
-from pydantic import BaseModel
-from torch import Generator, Tensor, cat, device as Device, dtype as DType, randn
-from torch.nn import Module
-from torch.nn.functional import mse_loss
-from torch.utils.data import Dataset
-from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore
-
-import refiners.fluxion.layers as fl
-from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from refiners.foundationals.latent_diffusion import (
-    DPMSolver,
-    SD1UNet,
-    StableDiffusion_1,
-)
-from refiners.foundationals.latent_diffusion.solvers import DDPM, Solver
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
-from refiners.training_utils.callback import Callback
-from refiners.training_utils.config import BaseConfig
-from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset
-from refiners.training_utils.trainer import Trainer
-from refiners.training_utils.wandb import WandbLoggable
-
-
-class LatentDiffusionConfig(BaseModel):
-    unconditional_sampling_probability: float = 0.2
-    offset_noise: float = 0.1
-    min_step: int = 0
-    max_step: int = 999
-
-
-class TestDiffusionConfig(BaseModel):
-    seed: int = 0
-    num_inference_steps: int = 30
-    use_short_prompts: bool = False
-    prompts: list[str] = []
-    num_images_per_prompt: int = 1
-
-
-class FinetuneLatentDiffusionConfig(BaseConfig):
-    dataset: HuggingfaceDatasetConfig
-    latent_diffusion: LatentDiffusionConfig
-    test_diffusion: TestDiffusionConfig
-
-
-@dataclass
-class TextEmbeddingLatentsBatch:
-    text_embeddings: Tensor
-    latents: Tensor
-
-
-class CaptionImage(TypedDict):
-    caption: str
-    image: Image.Image
-
-
-ConfigType = TypeVar("ConfigType", bound=FinetuneLatentDiffusionConfig)
-
-
-class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
-    def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
-        self.trainer = trainer
-        self.config = trainer.config
-        self.device = self.trainer.device
-        self.lda = self.trainer.lda
-        self.text_encoder = self.trainer.text_encoder
-        self.dataset = self.load_huggingface_dataset()
-        self.process_image = self.build_image_processor()
-        logger.info(f"Loaded {len(self.dataset)} samples from dataset")
-
-    def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
-        # TODO: make this configurable and add other transforms
-        transforms: list[Module] = []
-        if self.config.dataset.random_crop:
-            transforms.append(RandomCrop(size=512))
-        if self.config.dataset.horizontal_flip:
-            transforms.append(RandomHorizontalFlip(p=0.5))
-        if not transforms:
-            return lambda image: image
-        return Compose(transforms)
-
-    def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
-        dataset_config = self.config.dataset
-        logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return load_hf_dataset(
-            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
-        )
-
-    def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
-        return resize_image(image=image, min_size=min_size, max_size=max_size)
-
-    def process_caption(self, caption: str) -> str:
-        return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
-
-    def get_caption(self, index: int) -> str:
-        return self.dataset[index]["caption"]
-
-    def get_image(self, index: int) -> Image.Image:
-        return self.dataset[index]["image"]
-
-    def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
-        caption = self.get_caption(index=index)
-        image = self.get_image(index=index)
-        resized_image = self.resize_image(
-            image=image,
-            min_size=self.config.dataset.resize_image_min_size,
-            max_size=self.config.dataset.resize_image_max_size,
-        )
-        processed_image = self.process_image(resized_image)
-        latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
-        processed_caption = self.process_caption(caption=caption)
-        clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
-        return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
-
-    def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
-        text_embeddings = cat(tensors=[item.text_embeddings for item in batch])
-        latents = cat(tensors=[item.latents for item in batch])
-        return TextEmbeddingLatentsBatch(text_embeddings=text_embeddings, latents=latents)
-
-    def __len__(self) -> int:
-        return len(self.dataset)
-
-
-class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
-    @cached_property
-    def unet(self) -> SD1UNet:
-        assert self.config.models["unet"] is not None, "The config must contain a unet entry."
-        return SD1UNet(in_channels=4, device=self.device).to(device=self.device)
-
-    @cached_property
-    def text_encoder(self) -> CLIPTextEncoderL:
-        assert self.config.models["text_encoder"] is not None, "The config must contain a text_encoder entry."
-        return CLIPTextEncoderL(device=self.device).to(device=self.device)
-
-    @cached_property
-    def lda(self) -> SD1Autoencoder:
-        assert self.config.models["lda"] is not None, "The config must contain a lda entry."
-        return SD1Autoencoder(device=self.device).to(device=self.device)
-
-    def load_models(self) -> dict[str, fl.Module]:
-        return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
-
-    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
-        return TextEmbeddingLatentsDataset(trainer=self)
-
-    @cached_property
-    def ddpm_solver(self) -> Solver:
-        return DDPM(
-            num_inference_steps=1000,
-            device=self.device,
-        ).to(device=self.device)
-
-    def sample_timestep(self) -> Tensor:
-        random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
-        self.current_step = random_step
-        return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0)
-
-    def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
-        return sample_noise(
-            size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
-        )
-
-    def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
-        clip_text_embedding, latents = batch.text_embeddings, batch.latents
-        timestep = self.sample_timestep()
-        noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
-        noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step)
-        self.unet.set_timestep(timestep=timestep)
-        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
-        prediction = self.unet(noisy_latents)
-        loss = mse_loss(input=prediction, target=noise)
-        return loss
-
-    def compute_evaluation(self) -> None:
-        sd = StableDiffusion_1(
-            unet=self.unet,
-            lda=self.lda,
-            clip_text_encoder=self.text_encoder,
-            solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
-            device=self.device,
-        )
-        prompts = self.config.test_diffusion.prompts
-        num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt
-        if self.config.test_diffusion.use_short_prompts:
-            prompts = [prompt.split(sep=",")[0] for prompt in prompts]
-        images: dict[str, WandbLoggable] = {}
-        for prompt in prompts:
-            canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt))
-            for i in range(num_images_per_prompt):
-                logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
-                x = randn(1, 4, 64, 64, device=self.device)
-                clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device)
-                for step in sd.steps:
-                    x = sd(
-                        x,
-                        step=step,
-                        clip_text_embedding=clip_text_embedding,
-                    )
-                canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
-            images[prompt] = canvas_image
-        # TODO: wandb_log images
-
-
-def sample_noise(
-    size: tuple[int, ...],
-    offset_noise: float = 0.1,
-    device: Device | str = "cpu",
-    dtype: DType | None = None,
-    generator: Generator | None = None,
-) -> Tensor:
-    """Sample noise from a normal distribution.
-
-    If `offset_noise` is more than 0, the noise will be offset by a small amount. It allows the model to generate
-    images with a wider range of contrast https://www.crosslabs.org/blog/diffusion-with-offset-noise.
-    """
-    device = Device(device)
-    noise = randn(*size, generator=generator, device=device, dtype=dtype)
-    return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, device=device, dtype=dtype)
-
-
-def resize_image(image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
-    image_min_size = min(image.size)
-    if image_min_size > max_size:
-        if image_min_size == image.size[0]:
-            image = image.resize(size=(max_size, int(max_size * image.size[1] / image.size[0])))
-        else:
-            image = image.resize(size=(int(max_size * image.size[0] / image.size[1]), max_size))
-    if image_min_size < min_size:
-        if image_min_size == image.size[0]:
-            image = image.resize(size=(min_size, int(min_size * image.size[1] / image.size[0])))
-        else:
-            image = image.resize(size=(int(min_size * image.size[0] / image.size[1]), min_size))
-    return image
-
-
-class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
-    def on_train_begin(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        self.timestep_bins: dict[int, list[float]] = {i: [] for i in range(10)}
-
-    def on_compute_loss_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        loss_value = trainer.loss.detach().cpu().item()
-        current_step = trainer.current_step
-        bin_index = min(current_step // 100, 9)
-        self.timestep_bins[bin_index].append(loss_value)
-
-    def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        log_data: dict[str, WandbLoggable] = {}
-        for bin_index, losses in self.timestep_bins.items():
-            if losses:
-                avg_loss = sum(losses) / len(losses)
-                log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
-                self.timestep_bins[bin_index] = []
-
-        # TODO: wandb_log log_data
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 0d0e72d..86c9e55 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -1,6 +1,6 @@
 import random
 import time
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod, abstractproperty
 from functools import cached_property, wraps
 from typing import Any, Callable, Generic, Iterable, TypeVar, cast
 
@@ -250,6 +250,23 @@ Batch = TypeVar("Batch")
 ConfigType = TypeVar("ConfigType", bound=BaseConfig)
 
 
+class _Dataset(Dataset[Batch]):
+    """
+    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`]
+    """
+
+    def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
+        assert length > 0, "Dataset length must be greater than 0."
+        self.length = length
+        self.get_item = get_item
+
+    def __getitem__(self, index: int) -> Batch:
+        return self.get_item(index)
+
+    def __len__(self) -> int:
+        return self.length
+
+
 class Trainer(Generic[ConfigType, Batch], ABC):
     def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
         self.config = config
@@ -429,23 +446,44 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         ...
 
     @abstractmethod
-    def load_dataset(self) -> Dataset[Batch]:
+    def get_item(self, _: int) -> Batch:
+        """
+        Returns a batch of data.
+
+        This function is used by the dataloader to fetch a batch of data.
+        """
+        ...
+
+    @abstractproperty
+    def dataset_length(self) -> int:
+        """
+        Returns the length of the dataset.
+
+        This is used to compute the number of batches per epoch.
+        """
+        ...
+
+    @abstractmethod
+    def collate_fn(self, batch: list[Batch]) -> Batch:
+        """
+        Collate function for the dataloader.
+
+        This function is used to tell the dataloader how to combine a list of
+        batches into a single batch.
+        """
         ...
 
     @cached_property
     def dataset(self) -> Dataset[Batch]:
-        return self.load_dataset()
+        """
+        Returns the dataset constructed with the `get_item` method.
+        """
+        return _Dataset(get_item=self.get_item, length=self.dataset_length)
 
     @cached_property
-    def dataset_length(self) -> int:
-        assert hasattr(self.dataset, "__len__"), "The dataset must implement the `__len__` method."
-        return len(self.dataset)  # type: ignore
-
-    @cached_property
-    def dataloader(self) -> DataLoader[Batch]:
-        collate_fn = getattr(self.dataset, "collate_fn", None)
+    def dataloader(self) -> DataLoader[Any]:
         return DataLoader(
-            dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
+            dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=self.collate_fn
         )
 
     @abstractmethod
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 505c0df..26edd7c 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -7,7 +7,6 @@ import pytest
 import torch
 from torch import Tensor, nn
 from torch.optim import SGD
-from torch.utils.data import Dataset
 
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
@@ -27,20 +26,6 @@ class MockBatch:
     targets: torch.Tensor
 
 
-class MockDataset(Dataset[MockBatch]):
-    def __len__(self):
-        return 20
-
-    def __getitem__(self, _: int) -> MockBatch:
-        return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
-
-    def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
-        return MockBatch(
-            inputs=torch.cat([b.inputs for b in batch]),
-            targets=torch.cat([b.targets for b in batch]),
-        )
-
-
 class MockConfig(BaseConfig):
     pass
 
@@ -57,13 +42,23 @@ class MockModel(fl.Chain):
 class MockTrainer(Trainer[MockConfig, MockBatch]):
     step_counter: int = 0
 
+    @property
+    def dataset_length(self) -> int:
+        return 20
+
+    def get_item(self, _: int) -> MockBatch:
+        return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
+
+    def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
+        return MockBatch(
+            inputs=torch.cat([b.inputs for b in batch]),
+            targets=torch.cat([b.targets for b in batch]),
+        )
+
     @cached_property
     def mock_model(self) -> MockModel:
         return MockModel()
 
-    def load_dataset(self) -> Dataset[MockBatch]:
-        return MockDataset()
-
     def load_models(self) -> dict[str, fl.Module]:
         return {"mock_model": self.mock_model}
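With the `trainer.py` change above, downstream trainers no longer implement `load_dataset`; they implement `get_item`, `dataset_length`, and `collate_fn`, and `Trainer` builds the `torch.utils.data.Dataset` and `DataLoader` internally via the `_Dataset` wrapper. Below is a minimal sketch of a migrated trainer, modeled on the updated `MockTrainer` in the tests; the `MyConfig`/`MyBatch` names and the fixed dataset length are illustrative only, and the remaining abstract members (`load_models`, `compute_loss`, ...) are omitted.

```python
from dataclasses import dataclass

import torch

from refiners.training_utils import BaseConfig, Trainer  # re-exported by the updated __init__.py


@dataclass
class MyBatch:
    inputs: torch.Tensor
    targets: torch.Tensor


class MyConfig(BaseConfig):
    pass


class MyTrainer(Trainer[MyConfig, MyBatch]):
    @property
    def dataset_length(self) -> int:
        # Declared with `abstractproperty` in the base class, so override it as a plain property.
        return 20  # illustrative fixed length

    def get_item(self, index: int) -> MyBatch:
        # Called by the internal _Dataset wrapper to fetch a single item.
        return MyBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))

    def collate_fn(self, batch: list[MyBatch]) -> MyBatch:
        # Passed to the DataLoader to combine items returned by get_item into one batch.
        return MyBatch(
            inputs=torch.cat([b.inputs for b in batch]),
            targets=torch.cat([b.targets for b in batch]),
        )

    # load_models, compute_loss, etc. are unchanged by this diff and omitted here.
```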