diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 7b20f35..ed24453 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -208,6 +208,8 @@ class HuggingfaceDatasetConfig(BaseModel):
     hf_repo: str = "finegrain/unsplash-dummy"
     revision: str = "main"
     split: str = "train"
+    horizontal_flip: bool = False
+    random_crop: bool = True
     use_verification: bool = False
 
 
diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index f4f6d03..eee6d0f 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Any, TypeVar, TypedDict, cast
+from typing import Any, TypeVar, TypedDict, Callable
 from pydantic import BaseModel
 from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
 from loguru import logger
 from torch.utils.data import Dataset
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from torchvision.transforms import RandomCrop  # type: ignore
+from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore
 import refiners.fluxion.layers as fl
 from PIL import Image
 from functools import cached_property
@@ -23,6 +23,7 @@ from refiners.training_utils.wandb import WandbLoggable
 from refiners.training_utils.trainer import Trainer
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
+from torch.nn import Module
 
 
 class LatentDiffusionConfig(BaseModel):
@@ -67,15 +68,25 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         self.lda = self.trainer.lda
         self.text_encoder = self.trainer.text_encoder
         self.dataset = self.load_huggingface_dataset()
-        self.process_image = RandomCrop(size=512)  # TODO: make this configurable and add other transforms
+        self.process_image = self.build_image_processor()
         logger.info(f"Loaded {len(self.dataset)} samples from dataset")
 
+    def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
+        # TODO: make crop size and flip probability configurable; add other transforms
+        transforms: list[Module] = []
+        if self.config.dataset.random_crop:
+            transforms.append(RandomCrop(size=512))
+        if self.config.dataset.horizontal_flip:
+            transforms.append(RandomHorizontalFlip(p=0.5))
+        if not transforms:
+            return lambda image: image
+        return Compose(transforms)
+
     def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
         dataset_config = self.config.dataset
         logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return cast(
-            HuggingfaceDataset[CaptionImage],
-            load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
+        return load_hf_dataset(
+            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
         )
 
     def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
@@ -88,7 +99,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         item = self.dataset[index]
         caption, image = item["caption"], item["image"]
         resized_image = self.resize_image(image=image)
-        processed_image: Image.Image = self.process_image(resized_image)
+        processed_image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
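For reviewers who want to exercise the new behavior in isolation, here is a minimal standalone sketch of the conditional transform pipeline this patch introduces. It mirrors the `build_image_processor` logic outside the trainer; the `DatasetFlags` dataclass is a hypothetical stand-in for `HuggingfaceDatasetConfig` and is not part of the patch.

```python
# Standalone sketch of the conditional transform pipeline added by this patch.
# `DatasetFlags` is a hypothetical stand-in for `HuggingfaceDatasetConfig`.
from dataclasses import dataclass
from typing import Callable

from PIL import Image
from torch.nn import Module
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore


@dataclass
class DatasetFlags:
    random_crop: bool = True
    horizontal_flip: bool = False


def build_image_processor(flags: DatasetFlags) -> Callable[[Image.Image], Image.Image]:
    transforms: list[Module] = []
    if flags.random_crop:
        transforms.append(RandomCrop(size=512))
    if flags.horizontal_flip:
        transforms.append(RandomHorizontalFlip(p=0.5))
    if not transforms:
        # No augmentation enabled: hand the image back untouched.
        return lambda image: image
    return Compose(transforms)


if __name__ == "__main__":
    process = build_image_processor(DatasetFlags(random_crop=True, horizontal_flip=True))
    image = Image.new(mode="RGB", size=(576, 576))  # mimics resize_image's 512-576 output range
    print(process(image).size)  # (512, 512): RandomCrop(size=512) fixes the spatial size
```

Returning a plain identity lambda when both flags are off keeps `process_image` unconditionally callable, so `__getitem__` needs no branching on the dataset configuration.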