mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
remove huggingface datasets from default config
This commit is contained in:
parent
0f560437bc
commit
6a1fac876b
|
@ -10,6 +10,7 @@ import refiners.fluxion.layers as fl
|
|||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
||||
from refiners.training_utils.latent_diffusion import (
|
||||
FinetuneLatentDiffusionConfig,
|
||||
LatentDiffusionConfig,
|
||||
|
@ -50,6 +51,7 @@ class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
|
|||
|
||||
|
||||
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||
dataset: HuggingfaceDatasetConfig
|
||||
latent_diffusion: LatentDiffusionConfig
|
||||
lora: LoraConfig
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExten
|
|||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
||||
from refiners.training_utils.latent_diffusion import (
|
||||
FinetuneLatentDiffusionConfig,
|
||||
LatentDiffusionConfig,
|
||||
|
@ -112,6 +113,7 @@ class TextualInversionConfig(BaseModel):
|
|||
|
||||
|
||||
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||
dataset: HuggingfaceDatasetConfig
|
||||
latent_diffusion: LatentDiffusionConfig
|
||||
textual_inversion: TextualInversionConfig
|
||||
|
||||
|
|
|
@ -206,17 +206,6 @@ class WandbConfig(BaseModel):
|
|||
notes: str | None = None
|
||||
|
||||
|
||||
class HuggingfaceDatasetConfig(BaseModel):
|
||||
hf_repo: str = "finegrain/unsplash-dummy"
|
||||
revision: str = "main"
|
||||
split: str = "train"
|
||||
horizontal_flip: bool = False
|
||||
random_crop: bool = True
|
||||
use_verification: bool = False
|
||||
resize_image_min_size: int = 512
|
||||
resize_image_max_size: int = 576
|
||||
|
||||
|
||||
class CheckpointingConfig(BaseModel):
|
||||
save_folder: Path | None = None
|
||||
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
|
||||
|
@ -237,7 +226,6 @@ class BaseConfig(BaseModel):
|
|||
optimizer: OptimizerConfig
|
||||
scheduler: SchedulerConfig
|
||||
dropout: DropoutConfig
|
||||
dataset: HuggingfaceDatasetConfig
|
||||
checkpointing: CheckpointingConfig
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Any, Generic, Protocol, TypeVar, cast
|
||||
|
||||
from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
|
||||
|
||||
|
@ -22,3 +23,14 @@ def load_hf_dataset(
|
|||
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
|
||||
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
|
||||
return cast(HuggingfaceDataset[Any], dataset)
|
||||
|
||||
|
||||
class HuggingfaceDatasetConfig(BaseModel):
|
||||
hf_repo: str = "finegrain/unsplash-dummy"
|
||||
revision: str = "main"
|
||||
split: str = "train"
|
||||
horizontal_flip: bool = False
|
||||
random_crop: bool = True
|
||||
use_verification: bool = False
|
||||
resize_image_min_size: int = 512
|
||||
resize_image_max_size: int = 576
|
||||
|
|
|
@ -23,7 +23,7 @@ from refiners.foundationals.latent_diffusion.schedulers import DDPM
|
|||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
|
||||
from refiners.training_utils.callback import Callback
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, load_hf_dataset
|
||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset
|
||||
from refiners.training_utils.trainer import Trainer
|
||||
from refiners.training_utils.wandb import WandbLoggable
|
||||
|
||||
|
@ -44,6 +44,7 @@ class TestDiffusionConfig(BaseModel):
|
|||
|
||||
|
||||
class FinetuneLatentDiffusionConfig(BaseConfig):
|
||||
dataset: HuggingfaceDatasetConfig
|
||||
latent_diffusion: LatentDiffusionConfig
|
||||
test_diffusion: TestDiffusionConfig
|
||||
|
||||
|
|
Loading…
Reference in a new issue