remove huggingface datasets from default config

This commit is contained in:
limiteinductive 2023-12-15 17:12:54 +01:00 committed by Benjamin Trom
parent 0f560437bc
commit 6a1fac876b
5 changed files with 18 additions and 13 deletions

View file

@@ -10,6 +10,7 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
@@ -50,6 +51,7 @@ class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
lora: LoraConfig

View file

@@ -11,6 +11,7 @@ from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExten
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
@@ -112,6 +113,7 @@ class TextualInversionConfig(BaseModel):
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
textual_inversion: TextualInversionConfig

View file

@@ -206,17 +206,6 @@ class WandbConfig(BaseModel):
notes: str | None = None
class HuggingfaceDatasetConfig(BaseModel):
hf_repo: str = "finegrain/unsplash-dummy"
revision: str = "main"
split: str = "train"
horizontal_flip: bool = False
random_crop: bool = True
use_verification: bool = False
resize_image_min_size: int = 512
resize_image_max_size: int = 576
class CheckpointingConfig(BaseModel):
save_folder: Path | None = None
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
@@ -237,7 +226,6 @@ class BaseConfig(BaseModel):
optimizer: OptimizerConfig
scheduler: SchedulerConfig
dropout: DropoutConfig
dataset: HuggingfaceDatasetConfig
checkpointing: CheckpointingConfig
@classmethod

View file

@@ -1,6 +1,7 @@
from typing import Any, Generic, Protocol, TypeVar, cast
from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
from pydantic import BaseModel # type: ignore
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
@@ -22,3 +23,14 @@ def load_hf_dataset(
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
return cast(HuggingfaceDataset[Any], dataset)
class HuggingfaceDatasetConfig(BaseModel):
hf_repo: str = "finegrain/unsplash-dummy"
revision: str = "main"
split: str = "train"
horizontal_flip: bool = False
random_crop: bool = True
use_verification: bool = False
resize_image_min_size: int = 512
resize_image_max_size: int = 576

View file

@@ -23,7 +23,7 @@ from refiners.foundationals.latent_diffusion.schedulers import DDPM
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, load_hf_dataset
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.wandb import WandbLoggable
@@ -44,6 +44,7 @@ class TestDiffusionConfig(BaseModel):
class FinetuneLatentDiffusionConfig(BaseConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
test_diffusion: TestDiffusionConfig