diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py index 73fa359..32bcf16 100644 --- a/scripts/training/finetune-ldm-lora.py +++ b/scripts/training/finetune-ldm-lora.py @@ -10,6 +10,7 @@ import refiners.fluxion.layers as fl from refiners.fluxion.utils import save_to_safetensors from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets from refiners.training_utils.callback import Callback +from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig from refiners.training_utils.latent_diffusion import ( FinetuneLatentDiffusionConfig, LatentDiffusionConfig, @@ -50,6 +51,7 @@ class TriggerPhraseDataset(TextEmbeddingLatentsDataset): class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig): + dataset: HuggingfaceDatasetConfig latent_diffusion: LatentDiffusionConfig lora: LoraConfig diff --git a/scripts/training/finetune-ldm-textual-inversion.py b/scripts/training/finetune-ldm-textual-inversion.py index 66bf44e..18142b3 100644 --- a/scripts/training/finetune-ldm-textual-inversion.py +++ b/scripts/training/finetune-ldm-textual-inversion.py @@ -11,6 +11,7 @@ from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExten from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder from refiners.foundationals.clip.tokenizer import CLIPTokenizer from refiners.training_utils.callback import Callback +from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig from refiners.training_utils.latent_diffusion import ( FinetuneLatentDiffusionConfig, LatentDiffusionConfig, @@ -112,6 +113,7 @@ class TextualInversionConfig(BaseModel): class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig): + dataset: HuggingfaceDatasetConfig latent_diffusion: LatentDiffusionConfig textual_inversion: TextualInversionConfig diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index fdf979a..d84750f 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -206,17 +206,6 @@ class WandbConfig(BaseModel): notes: str | None = None -class HuggingfaceDatasetConfig(BaseModel): - hf_repo: str = "finegrain/unsplash-dummy" - revision: str = "main" - split: str = "train" - horizontal_flip: bool = False - random_crop: bool = True - use_verification: bool = False - resize_image_min_size: int = 512 - resize_image_max_size: int = 576 - - class CheckpointingConfig(BaseModel): save_folder: Path | None = None save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH} @@ -237,7 +226,6 @@ class BaseConfig(BaseModel): optimizer: OptimizerConfig scheduler: SchedulerConfig dropout: DropoutConfig - dataset: HuggingfaceDatasetConfig checkpointing: CheckpointingConfig @classmethod diff --git a/src/refiners/training_utils/huggingface_datasets.py b/src/refiners/training_utils/huggingface_datasets.py index 5b41a33..6433323 100644 --- a/src/refiners/training_utils/huggingface_datasets.py +++ b/src/refiners/training_utils/huggingface_datasets.py @@ -1,6 +1,7 @@ from typing import Any, Generic, Protocol, TypeVar, cast from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore +from pydantic import BaseModel # type: ignore __all__ = ["load_hf_dataset", "HuggingfaceDataset"] @@ -22,3 +23,14 @@ def load_hf_dataset( verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode) return cast(HuggingfaceDataset[Any], dataset) + + +class HuggingfaceDatasetConfig(BaseModel): + hf_repo: str = "finegrain/unsplash-dummy" + revision: str = "main" + split: str = "train" + horizontal_flip: bool = False + random_crop: bool = True + use_verification: bool = False + resize_image_min_size: int = 512 + resize_image_max_size: int = 576 diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 200daaf..dff85fc 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -23,7 +23,7 @@ from refiners.foundationals.latent_diffusion.schedulers import DDPM from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.training_utils.callback import Callback from refiners.training_utils.config import BaseConfig -from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, load_hf_dataset +from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset from refiners.training_utils.trainer import Trainer from refiners.training_utils.wandb import WandbLoggable @@ -44,6 +44,7 @@ class TestDiffusionConfig(BaseModel): class FinetuneLatentDiffusionConfig(BaseConfig): + dataset: HuggingfaceDatasetConfig latent_diffusion: LatentDiffusionConfig test_diffusion: TestDiffusionConfig