mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
Make horizontal flipping configurable in training scripts
This commit is contained in:
parent
18c84c7b72
commit
437fa24368
|
@ -208,6 +208,8 @@ class HuggingfaceDatasetConfig(BaseModel):
|
|||
hf_repo: str = "finegrain/unsplash-dummy"
|
||||
revision: str = "main"
|
||||
split: str = "train"
|
||||
horizontal_flip: bool = False
|
||||
random_crop: bool = True
|
||||
use_verification: bool = False
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, TypeVar, TypedDict, cast
|
||||
from typing import Any, TypeVar, TypedDict, Callable
|
||||
from pydantic import BaseModel
|
||||
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
|
||||
from loguru import logger
|
||||
from torch.utils.data import Dataset
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from torchvision.transforms import RandomCrop # type: ignore
|
||||
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # type: ignore
|
||||
import refiners.fluxion.layers as fl
|
||||
from PIL import Image
|
||||
from functools import cached_property
|
||||
|
@ -23,6 +23,7 @@ from refiners.training_utils.wandb import WandbLoggable
|
|||
from refiners.training_utils.trainer import Trainer
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
class LatentDiffusionConfig(BaseModel):
|
||||
|
@ -67,15 +68,25 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
|||
self.lda = self.trainer.lda
|
||||
self.text_encoder = self.trainer.text_encoder
|
||||
self.dataset = self.load_huggingface_dataset()
|
||||
self.process_image = RandomCrop(size=512) # TODO: make this configurable and add other transforms
|
||||
self.process_image = self.build_image_processor()
|
||||
logger.info(f"Loaded {len(self.dataset)} samples from dataset")
|
||||
|
||||
def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
    """Build the callable used to augment an image.

    The pipeline is assembled from two flags read from
    ``self.config.dataset``:

    * ``random_crop`` adds ``RandomCrop(size=512)``;
    * ``horizontal_flip`` adds ``RandomHorizontalFlip(p=0.5)``.

    The crop, when enabled, is always placed before the flip.

    Returns:
        A ``Compose`` of the enabled transforms, or an identity function
        when neither flag is set.
    """
    # TODO: make this configurable and add other transforms
    dataset_config = self.config.dataset
    transforms: list[Module] = []
    if dataset_config.random_crop:
        transforms.append(RandomCrop(size=512))
    if dataset_config.horizontal_flip:
        transforms.append(RandomHorizontalFlip(p=0.5))
    if not transforms:
        # Nothing enabled: hand the image back untouched instead of
        # wrapping an empty transform list.
        return lambda image: image
    return Compose(transforms)
|
||||
|
||||
def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
|
||||
dataset_config = self.config.dataset
|
||||
logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
|
||||
return cast(
|
||||
HuggingfaceDataset[CaptionImage],
|
||||
load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
|
||||
return load_hf_dataset(
|
||||
path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
|
||||
)
|
||||
|
||||
def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
|
||||
|
@ -88,7 +99,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
|||
item = self.dataset[index]
|
||||
caption, image = item["caption"], item["image"]
|
||||
resized_image = self.resize_image(image=image)
|
||||
processed_image: Image.Image = self.process_image(resized_image)
|
||||
processed_image = self.process_image(resized_image)
|
||||
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
|
||||
processed_caption = self.process_caption(caption=caption)
|
||||
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
|
||||
|
|
Loading…
Reference in a new issue