Make Dataset part of the trainer

limiteinductive 2024-02-07 14:35:26 +00:00 committed by Benjamin Trom
parent 9883f24f9a
commit 2e526d35d1
5 changed files with 71 additions and 300 deletions
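In short: a Trainer subclass no longer returns a torch Dataset from `load_dataset`; it now implements `get_item`, `dataset_length`, and `collate_fn`, and the base class builds the dataset and dataloader from those. A minimal sketch of the new contract (hypothetical `MyBatch`/`MyConfig`/`MyTrainer` names; the other abstract members, such as `load_models`, are omitted):

from dataclasses import dataclass

import torch

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer


@dataclass
class MyBatch:
    inputs: torch.Tensor


class MyConfig(BaseConfig):
    pass


class MyTrainer(Trainer[MyConfig, MyBatch]):
    @property
    def dataset_length(self) -> int:
        # Number of samples; the base Trainer derives batches per epoch from it.
        return 20

    def get_item(self, index: int) -> MyBatch:
        # Produce one sample; the base Trainer wraps this into a torch Dataset.
        return MyBatch(inputs=torch.randn(1, 10))

    def collate_fn(self, batch: list[MyBatch]) -> MyBatch:
        # Merge individual samples into a single batch for the DataLoader.
        return MyBatch(inputs=torch.cat([b.inputs for b in batch]))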

View file

@@ -1,11 +0,0 @@
-from refiners.training_utils.latent_diffusion import FinetuneLatentDiffusionConfig, LatentDiffusionTrainer
-
-if __name__ == "__main__":
-    import sys
-
-    config_path = sys.argv[1]
-    config = FinetuneLatentDiffusionConfig.load_from_toml(
-        toml_path=config_path,
-    )
-    trainer = LatentDiffusionTrainer(config=config)
-    trainer.train()

View file

@@ -4,6 +4,9 @@ from importlib.metadata import requires
 from packaging.requirements import Requirement

+from refiners.training_utils.config import BaseConfig
+from refiners.training_utils.trainer import Trainer
+
 refiners_requires = requires("refiners")
 assert refiners_requires is not None
@@ -21,3 +24,9 @@ for dep in refiners_requires:
             file=sys.stderr,
         )
         sys.exit(1)
+
+
+__all__ = [
+    "Trainer",
+    "BaseConfig",
+]
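With `Trainer` and `BaseConfig` re-exported here, downstream code can presumably import them from the package root (a hypothetical usage line, assuming the module path shown above):

from refiners.training_utils import BaseConfig, Trainer  # re-exported via __all__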

View file

@@ -1,260 +0,0 @@
-import random
-from dataclasses import dataclass
-from functools import cached_property
-from typing import Any, Callable, TypedDict, TypeVar
-
-from loguru import logger
-from PIL import Image
-from pydantic import BaseModel
-from torch import Generator, Tensor, cat, device as Device, dtype as DType, randn
-from torch.nn import Module
-from torch.nn.functional import mse_loss
-from torch.utils.data import Dataset
-from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore
-
-import refiners.fluxion.layers as fl
-from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from refiners.foundationals.latent_diffusion import (
-    DPMSolver,
-    SD1UNet,
-    StableDiffusion_1,
-)
-from refiners.foundationals.latent_diffusion.solvers import DDPM, Solver
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
-from refiners.training_utils.callback import Callback
-from refiners.training_utils.config import BaseConfig
-from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset
-from refiners.training_utils.trainer import Trainer
-from refiners.training_utils.wandb import WandbLoggable
-
-
-class LatentDiffusionConfig(BaseModel):
-    unconditional_sampling_probability: float = 0.2
-    offset_noise: float = 0.1
-    min_step: int = 0
-    max_step: int = 999
-
-
-class TestDiffusionConfig(BaseModel):
-    seed: int = 0
-    num_inference_steps: int = 30
-    use_short_prompts: bool = False
-    prompts: list[str] = []
-    num_images_per_prompt: int = 1
-
-
-class FinetuneLatentDiffusionConfig(BaseConfig):
-    dataset: HuggingfaceDatasetConfig
-    latent_diffusion: LatentDiffusionConfig
-    test_diffusion: TestDiffusionConfig
-
-
-@dataclass
-class TextEmbeddingLatentsBatch:
-    text_embeddings: Tensor
-    latents: Tensor
-
-
-class CaptionImage(TypedDict):
-    caption: str
-    image: Image.Image
-
-
-ConfigType = TypeVar("ConfigType", bound=FinetuneLatentDiffusionConfig)
-
-
-class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
-    def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
-        self.trainer = trainer
-        self.config = trainer.config
-        self.device = self.trainer.device
-        self.lda = self.trainer.lda
-        self.text_encoder = self.trainer.text_encoder
-        self.dataset = self.load_huggingface_dataset()
-        self.process_image = self.build_image_processor()
-        logger.info(f"Loaded {len(self.dataset)} samples from dataset")
-
-    def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
-        # TODO: make this configurable and add other transforms
-        transforms: list[Module] = []
-        if self.config.dataset.random_crop:
-            transforms.append(RandomCrop(size=512))
-        if self.config.dataset.horizontal_flip:
-            transforms.append(RandomHorizontalFlip(p=0.5))
-        if not transforms:
-            return lambda image: image
-        return Compose(transforms)
-
-    def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
-        dataset_config = self.config.dataset
-        logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return load_hf_dataset(
-            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
-        )
-
-    def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
-        return resize_image(image=image, min_size=min_size, max_size=max_size)
-
-    def process_caption(self, caption: str) -> str:
-        return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""
-
-    def get_caption(self, index: int) -> str:
-        return self.dataset[index]["caption"]
-
-    def get_image(self, index: int) -> Image.Image:
-        return self.dataset[index]["image"]
-
-    def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
-        caption = self.get_caption(index=index)
-        image = self.get_image(index=index)
-        resized_image = self.resize_image(
-            image=image,
-            min_size=self.config.dataset.resize_image_min_size,
-            max_size=self.config.dataset.resize_image_max_size,
-        )
-        processed_image = self.process_image(resized_image)
-        latents = self.lda.image_to_latents(image=processed_image).to(device=self.device)
-        processed_caption = self.process_caption(caption=caption)
-        clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
-        return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
-
-    def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
-        text_embeddings = cat(tensors=[item.text_embeddings for item in batch])
-        latents = cat(tensors=[item.latents for item in batch])
-        return TextEmbeddingLatentsBatch(text_embeddings=text_embeddings, latents=latents)
-
-    def __len__(self) -> int:
-        return len(self.dataset)
-
-
-class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
-    @cached_property
-    def unet(self) -> SD1UNet:
-        assert self.config.models["unet"] is not None, "The config must contain a unet entry."
-        return SD1UNet(in_channels=4, device=self.device).to(device=self.device)
-
-    @cached_property
-    def text_encoder(self) -> CLIPTextEncoderL:
-        assert self.config.models["text_encoder"] is not None, "The config must contain a text_encoder entry."
-        return CLIPTextEncoderL(device=self.device).to(device=self.device)
-
-    @cached_property
-    def lda(self) -> SD1Autoencoder:
-        assert self.config.models["lda"] is not None, "The config must contain a lda entry."
-        return SD1Autoencoder(device=self.device).to(device=self.device)
-
-    def load_models(self) -> dict[str, fl.Module]:
-        return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
-
-    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
-        return TextEmbeddingLatentsDataset(trainer=self)
-
-    @cached_property
-    def ddpm_solver(self) -> Solver:
-        return DDPM(
-            num_inference_steps=1000,
-            device=self.device,
-        ).to(device=self.device)
-
-    def sample_timestep(self) -> Tensor:
-        random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
-        self.current_step = random_step
-        return self.ddpm_solver.timesteps[random_step].unsqueeze(dim=0)
-
-    def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
-        return sample_noise(
-            size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
-        )
-
-    def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
-        clip_text_embedding, latents = batch.text_embeddings, batch.latents
-        timestep = self.sample_timestep()
-        noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
-        noisy_latents = self.ddpm_solver.add_noise(x=latents, noise=noise, step=self.current_step)
-        self.unet.set_timestep(timestep=timestep)
-        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
-        prediction = self.unet(noisy_latents)
-        loss = mse_loss(input=prediction, target=noise)
-        return loss
-
-    def compute_evaluation(self) -> None:
-        sd = StableDiffusion_1(
-            unet=self.unet,
-            lda=self.lda,
-            clip_text_encoder=self.text_encoder,
-            solver=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
-            device=self.device,
-        )
-        prompts = self.config.test_diffusion.prompts
-        num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt
-        if self.config.test_diffusion.use_short_prompts:
-            prompts = [prompt.split(sep=",")[0] for prompt in prompts]
-        images: dict[str, WandbLoggable] = {}
-        for prompt in prompts:
-            canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt))
-            for i in range(num_images_per_prompt):
-                logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
-                x = randn(1, 4, 64, 64, device=self.device)
-                clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device)
-                for step in sd.steps:
-                    x = sd(
-                        x,
-                        step=step,
-                        clip_text_embedding=clip_text_embedding,
-                    )
-                canvas_image.paste(sd.lda.latents_to_image(x=x), box=(0, 512 * i))
-            images[prompt] = canvas_image
-        # TODO: wandb_log images
-
-
-def sample_noise(
-    size: tuple[int, ...],
-    offset_noise: float = 0.1,
-    device: Device | str = "cpu",
-    dtype: DType | None = None,
-    generator: Generator | None = None,
-) -> Tensor:
-    """Sample noise from a normal distribution.
-
-    If `offset_noise` is more than 0, the noise will be offset by a small amount. It allows the model to generate
-    images with a wider range of contrast https://www.crosslabs.org/blog/diffusion-with-offset-noise.
-    """
-    device = Device(device)
-    noise = randn(*size, generator=generator, device=device, dtype=dtype)
-    return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, device=device, dtype=dtype)
-
-
-def resize_image(image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
-    image_min_size = min(image.size)
-    if image_min_size > max_size:
-        if image_min_size == image.size[0]:
-            image = image.resize(size=(max_size, int(max_size * image.size[1] / image.size[0])))
-        else:
-            image = image.resize(size=(int(max_size * image.size[0] / image.size[1]), max_size))
-    if image_min_size < min_size:
-        if image_min_size == image.size[0]:
-            image = image.resize(size=(min_size, int(min_size * image.size[1] / image.size[0])))
-        else:
-            image = image.resize(size=(int(min_size * image.size[0] / image.size[1]), min_size))
-    return image
-
-
-class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
-    def on_train_begin(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        self.timestep_bins: dict[int, list[float]] = {i: [] for i in range(10)}
-
-    def on_compute_loss_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        loss_value = trainer.loss.detach().cpu().item()
-        current_step = trainer.current_step
-        bin_index = min(current_step // 100, 9)
-        self.timestep_bins[bin_index].append(loss_value)
-
-    def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        log_data: dict[str, WandbLoggable] = {}
-        for bin_index, losses in self.timestep_bins.items():
-            if losses:
-                avg_loss = sum(losses) / len(losses)
-                log_data[f"average_loss_timestep_bin_{bin_index * 100}"] = avg_loss
-                self.timestep_bins[bin_index] = []
-        # TODO: wandb_log log_data
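The sample_noise helper above implements offset noise: on top of the per-pixel Gaussian sample, it adds one extra Gaussian scalar per (batch, channel) pair, broadcast over the whole spatial plane, which shifts each channel's mean and lets the model reach a wider range of contrast. A standalone sketch of that broadcasting in plain PyTorch (independent of refiners):

import torch

b, c, h, w = 2, 4, 64, 64
noise = torch.randn(b, c, h, w)
# One scalar per (batch, channel), broadcast over the full H x W plane.
offset = 0.1 * torch.randn(b, c, 1, 1)
noisy = noise + offset
# Each channel's mean shifts by exactly its offset scalar (up to float error).
print(torch.allclose(noisy.mean(dim=(2, 3)) - noise.mean(dim=(2, 3)), offset.squeeze()))  # True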

View file

@@ -1,6 +1,6 @@
 import random
 import time
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod, abstractproperty
 from functools import cached_property, wraps
 from typing import Any, Callable, Generic, Iterable, TypeVar, cast
@@ -250,6 +250,23 @@ Batch = TypeVar("Batch")
 ConfigType = TypeVar("ConfigType", bound=BaseConfig)


+class _Dataset(Dataset[Batch]):
+    """
+    A wrapper around the `get_item` method to create a [`torch.utils.data.Dataset`].
+    """
+
+    def __init__(self, get_item: Callable[[int], Batch], length: int) -> None:
+        assert length > 0, "Dataset length must be greater than 0."
+        self.length = length
+        self.get_item = get_item
+
+    def __getitem__(self, index: int) -> Batch:
+        return self.get_item(index)
+
+    def __len__(self) -> int:
+        return self.length
+
+
 class Trainer(Generic[ConfigType, Batch], ABC):
     def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
         self.config = config
@@ -429,23 +446,44 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         ...

     @abstractmethod
-    def load_dataset(self) -> Dataset[Batch]:
+    def get_item(self, _: int) -> Batch:
+        """
+        Returns a batch of data.
+
+        This function is used by the dataloader to fetch a batch of data.
+        """
+        ...
+
+    @abstractproperty
+    def dataset_length(self) -> int:
+        """
+        Returns the length of the dataset.
+
+        This is used to compute the number of batches per epoch.
+        """
+        ...
+
+    @abstractmethod
+    def collate_fn(self, batch: list[Batch]) -> Batch:
+        """
+        Collate function for the dataloader.
+
+        This function is used to tell the dataloader how to combine a list of
+        batches into a single batch.
+        """
         ...

     @cached_property
     def dataset(self) -> Dataset[Batch]:
-        return self.load_dataset()
+        """
+        Returns the dataset constructed with the `get_item` method.
+        """
+        return _Dataset(get_item=self.get_item, length=self.dataset_length)

     @cached_property
-    def dataset_length(self) -> int:
-        assert hasattr(self.dataset, "__len__"), "The dataset must implement the `__len__` method."
-        return len(self.dataset)  # type: ignore
-
-    @cached_property
-    def dataloader(self) -> DataLoader[Batch]:
-        collate_fn = getattr(self.dataset, "collate_fn", None)
+    def dataloader(self) -> DataLoader[Any]:
         return DataLoader(
-            dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
+            dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=self.collate_fn
         )

     @abstractmethod
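The _Dataset shim above is the heart of the change: any index-based `get_item` plus a length becomes a map-style torch Dataset, and batching stays explicit through `collate_fn` instead of the old `getattr(self.dataset, "collate_fn", None)` lookup. A self-contained toy version of the same pattern (hypothetical names, outside refiners):

from typing import Callable, TypeVar

from torch.utils.data import DataLoader, Dataset

T = TypeVar("T")


class CallableDataset(Dataset[T]):
    # Same idea as _Dataset above: delegate indexing to a plain function.
    def __init__(self, get_item: Callable[[int], T], length: int) -> None:
        assert length > 0, "Dataset length must be greater than 0."
        self.get_item = get_item
        self.length = length

    def __getitem__(self, index: int) -> T:
        return self.get_item(index)

    def __len__(self) -> int:
        return self.length


dataset = CallableDataset(get_item=lambda i: i * i, length=4)
# The DataLoader batches via an explicit collate_fn, as the new Trainer does.
loader = DataLoader(dataset, batch_size=2, collate_fn=lambda items: items)
print(list(loader))  # [[0, 1], [4, 9]]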

View file

@@ -7,7 +7,6 @@ import pytest
 import torch
 from torch import Tensor, nn
 from torch.optim import SGD
-from torch.utils.data import Dataset

 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
@@ -27,20 +26,6 @@ class MockBatch:
     targets: torch.Tensor


-class MockDataset(Dataset[MockBatch]):
-    def __len__(self):
-        return 20
-
-    def __getitem__(self, _: int) -> MockBatch:
-        return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
-
-    def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
-        return MockBatch(
-            inputs=torch.cat([b.inputs for b in batch]),
-            targets=torch.cat([b.targets for b in batch]),
-        )
-
-
 class MockConfig(BaseConfig):
     pass
@@ -57,13 +42,23 @@ class MockModel(fl.Chain):
 class MockTrainer(Trainer[MockConfig, MockBatch]):
     step_counter: int = 0

+    @property
+    def dataset_length(self) -> int:
+        return 20
+
+    def get_item(self, _: int) -> MockBatch:
+        return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
+
+    def collate_fn(self, batch: list[MockBatch]) -> MockBatch:
+        return MockBatch(
+            inputs=torch.cat([b.inputs for b in batch]),
+            targets=torch.cat([b.targets for b in batch]),
+        )
+
     @cached_property
     def mock_model(self) -> MockModel:
         return MockModel()

-    def load_dataset(self) -> Dataset[MockBatch]:
-        return MockDataset()
-
     def load_models(self) -> dict[str, fl.Module]:
         return {"mock_model": self.mock_model}