Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-10 07:21:59 +00:00
Make horizontal flipping configurable in training scripts
This commit is contained in:
parent 18c84c7b72
commit 437fa24368
@@ -208,6 +208,8 @@ class HuggingfaceDatasetConfig(BaseModel):
     hf_repo: str = "finegrain/unsplash-dummy"
     revision: str = "main"
     split: str = "train"
+    horizontal_flip: bool = False
+    random_crop: bool = True
     use_verification: bool = False
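For reference, the two new fields slot into the existing pydantic config. A minimal sketch of how they can be set; the class stub below mirrors only the fields visible in this hunk, while the real HuggingfaceDatasetConfig in refiners.training_utils carries more options:

from pydantic import BaseModel

# Stand-in with just the fields shown in the hunk above; illustrative only.
class HuggingfaceDatasetConfig(BaseModel):
    hf_repo: str = "finegrain/unsplash-dummy"
    revision: str = "main"
    split: str = "train"
    horizontal_flip: bool = False
    random_crop: bool = True
    use_verification: bool = False

# Opt into flipping; cropping stays on via its default.
config = HuggingfaceDatasetConfig(horizontal_flip=True)
assert config.random_crop and config.horizontal_flip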
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Any, TypeVar, TypedDict, cast
+from typing import Any, TypeVar, TypedDict, Callable
 from pydantic import BaseModel
 from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
 from loguru import logger
 from torch.utils.data import Dataset
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from torchvision.transforms import RandomCrop  # type: ignore
+from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore
 import refiners.fluxion.layers as fl
 from PIL import Image
 from functools import cached_property
@@ -23,6 +23,7 @@ from refiners.training_utils.wandb import WandbLoggable
 from refiners.training_utils.trainer import Trainer
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.huggingface_datasets import load_hf_dataset, HuggingfaceDataset
+from torch.nn import Module


 class LatentDiffusionConfig(BaseModel):
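The new torch.nn.Module import types the transform list built further down: torchvision's transform classes subclass nn.Module, so list[Module] is a natural element type. A quick illustrative check, not part of the commit:

from torch.nn import Module
from torchvision.transforms import RandomCrop, RandomHorizontalFlip  # type: ignore

# Both transforms are nn.Module subclasses, hence list[Module] type-checks.
assert isinstance(RandomCrop(size=512), Module)
assert isinstance(RandomHorizontalFlip(p=0.5), Module)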
@@ -67,15 +68,25 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         self.lda = self.trainer.lda
         self.text_encoder = self.trainer.text_encoder
         self.dataset = self.load_huggingface_dataset()
-        self.process_image = RandomCrop(size=512)  # TODO: make this configurable and add other transforms
+        self.process_image = self.build_image_processor()
         logger.info(f"Loaded {len(self.dataset)} samples from dataset")

+    def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
+        # TODO: make this configurable and add other transforms
+        transforms: list[Module] = []
+        if self.config.dataset.random_crop:
+            transforms.append(RandomCrop(size=512))
+        if self.config.dataset.horizontal_flip:
+            transforms.append(RandomHorizontalFlip(p=0.5))
+        if not transforms:
+            return lambda image: image
+        return Compose(transforms)
+
     def load_huggingface_dataset(self) -> HuggingfaceDataset[CaptionImage]:
         dataset_config = self.config.dataset
         logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return cast(
-            HuggingfaceDataset[CaptionImage],
-            load_hf_dataset(path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split),
-        )
+        return load_hf_dataset(
+            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
+        )

     def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
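To see what build_image_processor produces, here is a self-contained sketch of the same logic with the flags passed directly instead of read from self.config.dataset; illustrative only:

from typing import Callable

from PIL import Image
from torch.nn import Module
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip  # type: ignore


def build_image_processor(random_crop: bool, horizontal_flip: bool) -> Callable[[Image.Image], Image.Image]:
    transforms: list[Module] = []
    if random_crop:
        transforms.append(RandomCrop(size=512))
    if horizontal_flip:
        transforms.append(RandomHorizontalFlip(p=0.5))
    if not transforms:
        # No augmentation configured: fall back to the identity.
        return lambda image: image
    return Compose(transforms)


# RandomCrop(size=512) expects inputs at least 512x512; in the training script,
# resize_image guarantees this by resizing into the [512, 576] range first.
image = Image.new("RGB", (576, 576))
process = build_image_processor(random_crop=True, horizontal_flip=True)
assert process(image).size == (512, 512)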
@@ -88,7 +99,7 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
         item = self.dataset[index]
         caption, image = item["caption"], item["image"]
         resized_image = self.resize_image(image=image)
-        processed_image: Image.Image = self.process_image(resized_image)
+        processed_image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)
         clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)