Make horizontal flipping configurable in training scripts

This commit is contained in:
Doryan Kaced 2023-08-30 10:52:23 +02:00
parent 18c84c7b72
commit 437fa24368
2 changed files with 20 additions and 7 deletions

View file

@@ -208,6 +208,8 @@ class HuggingfaceDatasetConfig(BaseModel):
hf_repo: str = "finegrain/unsplash-dummy" hf_repo: str = "finegrain/unsplash-dummy"
revision: str = "main" revision: str = "main"
split: str = "train" split: str = "train"
horizontal_flip: bool = False
random_crop: bool = True
use_verification: bool = False use_verification: bool = False

View file

@@ -1,11 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, TypeVar, TypedDict, cast from typing import Any, TypeVar, TypedDict, Callable
from pydantic import BaseModel from pydantic import BaseModel
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
from loguru import logger from loguru import logger
from torch.utils.data import Dataset from torch.utils.data import Dataset
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from torchvision.transforms import RandomCrop # type: ignore from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # type: ignore
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from PIL import Image from PIL import Image
from functools import cached_property from functools import cached_property
@@ -23,6 +23,7 @@ from refiners.training_utils.wandb import WandbLoggable
from refiners.training_utils.trainer import Trainer from refiners.training_utils.trainer import Trainer
from refiners.training_utils.callback import Callback from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
from torch.nn import Module
class LatentDiffusionConfig(BaseModel): class LatentDiffusionConfig(BaseModel):
@@ -67,15 +68,25 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
self.lda = self.trainer.lda self.lda = self.trainer.lda
self.text_encoder = self.trainer.text_encoder self.text_encoder = self.trainer.text_encoder
self.dataset = self.load_huggingface_dataset() self.dataset = self.load_huggingface_dataset()
self.process_image = RandomCrop(size=512) # TODO: make this configurable and add other transforms self.process_image = self.build_image_processor()
logger.info(f"Loaded {len(self.dataset)} samples from dataset") logger.info(f"Loaded {len(self.dataset)} samples from dataset")
def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
    """Assemble the image transform pipeline from the dataset config.

    Adds a random crop (size=512) when `random_crop` is enabled and a
    horizontal flip (p=0.5) when `horizontal_flip` is enabled; when both
    are disabled, returns an identity function so callers can always
    apply the result unconditionally.
    """
    # TODO: make this configurable and add other transforms
    dataset_config = self.config.dataset
    pipeline: list[Module] = []
    if dataset_config.random_crop:
        pipeline.append(RandomCrop(size=512))
    if dataset_config.horizontal_flip:
        pipeline.append(RandomHorizontalFlip(p=0.5))
    return Compose(pipeline) if pipeline else (lambda image: image)
def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]: def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
dataset_config = self.config.dataset dataset_config = self.config.dataset
logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}") logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
return cast( return load_hf_dataset(
HuggingfaceDataset[CaptionImage], path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
) )
def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image: def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
@@ -88,7 +99,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
item = self.dataset[index] item = self.dataset[index]
caption, image = item["caption"], item["image"] caption, image = item["caption"], item["image"]
resized_image = self.resize_image(image=image) resized_image = self.resize_image(image=image)
processed_image: Image.Image = self.process_image(resized_image) processed_image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device) latents = self.lda.encode_image(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption) processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device) clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)