Make horizontal flipping configurable in training scripts

Doryan Kaced 2023-08-30 10:52:23 +02:00
parent 18c84c7b72
commit 437fa24368
2 changed files with 20 additions and 7 deletions


@@ -208,6 +208,8 @@ class HuggingfaceDatasetConfig(BaseModel):
     hf_repo: str = "finegrain/unsplash-dummy"
     revision: str = "main"
     split: str = "train"
+    horizontal_flip: bool = False
+    random_crop: bool = True
     use_verification: bool = False
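
For reference, a minimal usage sketch (not part of the diff): the two new options are ordinary pydantic fields on HuggingfaceDatasetConfig, so they can be set when building the dataset config.

    # Hypothetical usage: enable the new flags on HuggingfaceDatasetConfig.
    dataset_config = HuggingfaceDatasetConfig(
        hf_repo="finegrain/unsplash-dummy",
        revision="main",
        split="train",
        horizontal_flip=True,  # new: apply RandomHorizontalFlip(p=0.5) during training
        random_crop=True,      # new: keep the 512px RandomCrop
    )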


@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Any, TypeVar, TypedDict, cast
+from typing import Any, TypeVar, TypedDict, Callable
 from pydantic import BaseModel
 from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
 from loguru import logger
 from torch.utils.data import Dataset
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from torchvision.transforms import RandomCrop # type: ignore
+from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # type: ignore
 import refiners.fluxion.layers as fl
 from PIL import Image
 from functools import cached_property
@@ -23,6 +23,7 @@ from refiners.training_utils.wandb import WandbLoggable
 from refiners.training_utils.trainer import Trainer
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
+from torch.nn import Module
 
 class LatentDiffusionConfig(BaseModel):
@@ -67,15 +68,25 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         self.lda = self.trainer.lda
         self.text_encoder = self.trainer.text_encoder
         self.dataset = self.load_huggingface_dataset()
-        self.process_image = RandomCrop(size=512) # TODO: make this configurable and add other transforms
+        self.process_image = self.build_image_processor()
         logger.info(f"Loaded {len(self.dataset)} samples from dataset")
 
+    def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
+        # TODO: make this configurable and add other transforms
+        transforms: list[Module] = []
+        if self.config.dataset.random_crop:
+            transforms.append(RandomCrop(size=512))
+        if self.config.dataset.horizontal_flip:
+            transforms.append(RandomHorizontalFlip(p=0.5))
+        if not transforms:
+            return lambda image: image
+        return Compose(transforms)
+
     def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
         dataset_config = self.config.dataset
         logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return cast(
-            HuggingfaceDataset[CaptionImage],
-            load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
-        )
+        return load_hf_dataset(
+            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
+        )
 
     def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
@@ -88,7 +99,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         item = self.dataset[index]
         caption, image = item["caption"], item["image"]
         resized_image = self.resize_image(image=image)
-        processed_image: Image.Image = self.process_image(resized_image)
+        processed_image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
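
Taken out of the trainer, the new image-processing pipeline boils down to the standalone sketch below (a rewrite of build_image_processor with the config fields passed as plain arguments; assumes torchvision is installed and the input image is at least 512px on each side, which resize_image guarantees):

    from typing import Callable
    from PIL import Image
    from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore

    def build_image_processor(random_crop: bool, horizontal_flip: bool) -> Callable[[Image.Image], Image.Image]:
        transforms = []
        if random_crop:
            transforms.append(RandomCrop(size=512))
        if horizontal_flip:
            transforms.append(RandomHorizontalFlip(p=0.5))
        if not transforms:
            return lambda image: image  # no-op when every transform is disabled
        return Compose(transforms)

    process_image = build_image_processor(random_crop=True, horizontal_flip=True)
    image = Image.new("RGB", (576, 576))
    processed = process_image(image)  # 512x512 crop, flipped with probability 0.5

The defaults (random_crop=True, horizontal_flip=False) preserve the previous behavior, where a 512px RandomCrop was always applied and no flipping was done.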