mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
remove huggingface datasets from default config
This commit is contained in:
parent
0f560437bc
commit
6a1fac876b
|
@ -10,6 +10,7 @@ import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.utils import save_to_safetensors
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
|
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
|
||||||
from refiners.training_utils.callback import Callback
|
from refiners.training_utils.callback import Callback
|
||||||
|
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
||||||
from refiners.training_utils.latent_diffusion import (
|
from refiners.training_utils.latent_diffusion import (
|
||||||
FinetuneLatentDiffusionConfig,
|
FinetuneLatentDiffusionConfig,
|
||||||
LatentDiffusionConfig,
|
LatentDiffusionConfig,
|
||||||
|
@ -50,6 +51,7 @@ class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
|
||||||
|
|
||||||
|
|
||||||
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||||
|
dataset: HuggingfaceDatasetConfig
|
||||||
latent_diffusion: LatentDiffusionConfig
|
latent_diffusion: LatentDiffusionConfig
|
||||||
lora: LoraConfig
|
lora: LoraConfig
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ from refiners.foundationals.clip.concepts import ConceptExtender, EmbeddingExten
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
from refiners.training_utils.callback import Callback
|
from refiners.training_utils.callback import Callback
|
||||||
|
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
|
||||||
from refiners.training_utils.latent_diffusion import (
|
from refiners.training_utils.latent_diffusion import (
|
||||||
FinetuneLatentDiffusionConfig,
|
FinetuneLatentDiffusionConfig,
|
||||||
LatentDiffusionConfig,
|
LatentDiffusionConfig,
|
||||||
|
@ -112,6 +113,7 @@ class TextualInversionConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
|
||||||
|
dataset: HuggingfaceDatasetConfig
|
||||||
latent_diffusion: LatentDiffusionConfig
|
latent_diffusion: LatentDiffusionConfig
|
||||||
textual_inversion: TextualInversionConfig
|
textual_inversion: TextualInversionConfig
|
||||||
|
|
||||||
|
|
|
@ -206,17 +206,6 @@ class WandbConfig(BaseModel):
|
||||||
notes: str | None = None
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceDatasetConfig(BaseModel):
|
|
||||||
hf_repo: str = "finegrain/unsplash-dummy"
|
|
||||||
revision: str = "main"
|
|
||||||
split: str = "train"
|
|
||||||
horizontal_flip: bool = False
|
|
||||||
random_crop: bool = True
|
|
||||||
use_verification: bool = False
|
|
||||||
resize_image_min_size: int = 512
|
|
||||||
resize_image_max_size: int = 576
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointingConfig(BaseModel):
|
class CheckpointingConfig(BaseModel):
|
||||||
save_folder: Path | None = None
|
save_folder: Path | None = None
|
||||||
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
|
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
|
||||||
|
@ -237,7 +226,6 @@ class BaseConfig(BaseModel):
|
||||||
optimizer: OptimizerConfig
|
optimizer: OptimizerConfig
|
||||||
scheduler: SchedulerConfig
|
scheduler: SchedulerConfig
|
||||||
dropout: DropoutConfig
|
dropout: DropoutConfig
|
||||||
dataset: HuggingfaceDatasetConfig
|
|
||||||
checkpointing: CheckpointingConfig
|
checkpointing: CheckpointingConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Any, Generic, Protocol, TypeVar, cast
|
from typing import Any, Generic, Protocol, TypeVar, cast
|
||||||
|
|
||||||
from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
|
from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
|
||||||
|
from pydantic import BaseModel # type: ignore
|
||||||
|
|
||||||
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
|
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
|
||||||
|
|
||||||
|
@ -22,3 +23,14 @@ def load_hf_dataset(
|
||||||
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
|
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
|
||||||
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
|
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
|
||||||
return cast(HuggingfaceDataset[Any], dataset)
|
return cast(HuggingfaceDataset[Any], dataset)
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceDatasetConfig(BaseModel):
|
||||||
|
hf_repo: str = "finegrain/unsplash-dummy"
|
||||||
|
revision: str = "main"
|
||||||
|
split: str = "train"
|
||||||
|
horizontal_flip: bool = False
|
||||||
|
random_crop: bool = True
|
||||||
|
use_verification: bool = False
|
||||||
|
resize_image_min_size: int = 512
|
||||||
|
resize_image_max_size: int = 576
|
||||||
|
|
|
@ -23,7 +23,7 @@ from refiners.foundationals.latent_diffusion.schedulers import DDPM
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
|
||||||
from refiners.training_utils.callback import Callback
|
from refiners.training_utils.callback import Callback
|
||||||
from refiners.training_utils.config import BaseConfig
|
from refiners.training_utils.config import BaseConfig
|
||||||
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, load_hf_dataset
|
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset
|
||||||
from refiners.training_utils.trainer import Trainer
|
from refiners.training_utils.trainer import Trainer
|
||||||
from refiners.training_utils.wandb import WandbLoggable
|
from refiners.training_utils.wandb import WandbLoggable
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ class TestDiffusionConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class FinetuneLatentDiffusionConfig(BaseConfig):
|
class FinetuneLatentDiffusionConfig(BaseConfig):
|
||||||
|
dataset: HuggingfaceDatasetConfig
|
||||||
latent_diffusion: LatentDiffusionConfig
|
latent_diffusion: LatentDiffusionConfig
|
||||||
test_diffusion: TestDiffusionConfig
|
test_diffusion: TestDiffusionConfig
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue