From 437fa2436831c30019f4720abfa163336a40c846 Mon Sep 17 00:00:00 2001
From: Doryan Kaced
Date: Wed, 30 Aug 2023 10:52:23 +0200
Subject: [PATCH] Make horizontal flipping configurable in training scripts

Add horizontal_flip and random_crop flags to HuggingfaceDatasetConfig and
build the dataset image processor from them instead of hard-coding
RandomCrop(size=512).
---
 src/refiners/training_utils/config.py        |  2 ++
 .../training_utils/latent_diffusion.py       | 25 +++++++++++++------
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 7b20f35..ed24453 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -208,6 +208,8 @@ class HuggingfaceDatasetConfig(BaseModel):
     hf_repo: str = "finegrain/unsplash-dummy"
     revision: str = "main"
     split: str = "train"
+    horizontal_flip: bool = False
+    random_crop: bool = True
     use_verification: bool = False
 
 

diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index f4f6d03..eee6d0f 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Any, TypeVar, TypedDict, cast
+from typing import Any, TypeVar, TypedDict, Callable
 from pydantic import BaseModel
 from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
 from loguru import logger
 from torch.utils.data import Dataset
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from torchvision.transforms import RandomCrop  # type: ignore
+from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore
 import refiners.fluxion.layers as fl
 from PIL import Image
 from functools import cached_property
@@ -23,6 +23,7 @@ from refiners.training_utils.wandb import WandbLoggable
 from refiners.training_utils.trainer import Trainer
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
+from torch.nn import Module
 
 
 class LatentDiffusionConfig(BaseModel):
@@ -67,15 +68,25 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         self.lda = self.trainer.lda
         self.text_encoder = self.trainer.text_encoder
         self.dataset = self.load_huggingface_dataset()
-        self.process_image = RandomCrop(size=512)  # TODO: make this configurable and add other transforms
+        self.process_image = self.build_image_processor()
         logger.info(f"Loaded {len(self.dataset)} samples from dataset")
 
+    def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
+        # TODO: add other transforms
+        transforms: list[Module] = []
+        if self.config.dataset.random_crop:
+            transforms.append(RandomCrop(size=512))
+        if self.config.dataset.horizontal_flip:
+            transforms.append(RandomHorizontalFlip(p=0.5))
+        if not transforms:
+            return lambda image: image
+        return Compose(transforms)
+
     def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
         dataset_config = self.config.dataset
         logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return cast(
-            HuggingfaceDataset[CaptionImage],
-            load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
+        return load_hf_dataset(
+            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
         )
 
     def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
@@ -88,7 +99,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         item = self.dataset[index]
         caption, image = item["caption"], item["image"]
         resized_image = self.resize_image(image=image)
-        processed_image: Image.Image = self.process_image(resized_image)
+        processed_image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
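
For reference, below is a minimal standalone sketch of the gated transform
pipeline this patch introduces, runnable outside the trainer (it assumes
Pillow, torch, and torchvision are installed). It is not part of the diff:
the free function and its two boolean parameters are illustrative stand-ins
for the new HuggingfaceDatasetConfig flags consumed by
TextEmbeddingLatentsDataset.build_image_processor.

    # Sketch of the gated transform pipeline; the boolean parameters stand in
    # for HuggingfaceDatasetConfig.random_crop / .horizontal_flip.
    from typing import Callable

    from PIL import Image
    from torch.nn import Module
    from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore


    def build_image_processor(random_crop: bool, horizontal_flip: bool) -> Callable[[Image.Image], Image.Image]:
        transforms: list[Module] = []
        if random_crop:
            transforms.append(RandomCrop(size=512))
        if horizontal_flip:
            transforms.append(RandomHorizontalFlip(p=0.5))
        if not transforms:
            # Both flags off: images pass through unchanged.
            return lambda image: image
        return Compose(transforms)


    # A 576x576 input (resize_image's upper bound) is cropped to 512x512 and
    # flipped left-right with probability 0.5.
    image = Image.new("RGB", (576, 576))
    process = build_image_processor(random_crop=True, horizontal_flip=True)
    print(process(image).size)  # (512, 512)

Returning an identity lambda when both flags are off keeps the call site,
self.process_image(resized_image), unconditional.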