diff --git a/configs/finetune-ldm.toml b/configs/finetune-ldm.toml deleted file mode 100644 index 22ff820..0000000 --- a/configs/finetune-ldm.toml +++ /dev/null @@ -1,53 +0,0 @@ -[wandb] -mode = "offline" -entity = "acme" -project = "test-ldm-training" - -[models] -lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false} -text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true} -unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true} - -[latent_diffusion] -unconditional_sampling_probability = 0.2 -offset_noise = 0.1 - -[training] -duration = "1:epoch" -seed = 0 -gpu_index = 0 -num_epochs = 1 -batch_size = 1 -gradient_accumulation = "1:step" -clip_grad_norm = 2.0 -clip_grad_value = 1.0 -evaluation_interval = "1:epoch" -evaluation_seed = 0 - - -[optimizer] -optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam" -learning_rate = 1e-5 -betas = [0.9, 0.999] -eps = 1e-8 -weight_decay = 1e-2 - - -[scheduler] - - -[dropout] -dropout_probability = 0.2 - -[dataset] -hf_repo = "acme/images" -revision = "main" - -[checkpointing] -# save_folder = "/path/to/ckpts" -save_interval = "1:epoch" - -[test_diffusion] -prompts = [ - "A cute cat", -] diff --git a/configs/finetune-textual-inversion.toml b/configs/finetune-textual-inversion.toml deleted file mode 100644 index b82c34c..0000000 --- a/configs/finetune-textual-inversion.toml +++ /dev/null @@ -1,61 +0,0 @@ -[wandb] -mode = "offline" # "online", "offline", "disabled" -entity = "acme" -project = "test-textual-inversion" - -[models] -unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"} -text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"} -lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"} - -[latent_diffusion] -unconditional_sampling_probability = 0.05 -offset_noise = 0.1 - -[textual_inversion] -placeholder_token = "" -initializer_token = "toy" -# style_mode = true - -[training] -duration = "2000:step" -seed = 0 -gpu_index = 0 -batch_size = 4 -gradient_accumulation = "1:step" -evaluation_interval = "250:step" -evaluation_seed = 1 - -[optimizer] -optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" -learning_rate = 5e-4 -betas = [0.9, 0.999] -eps = 1e-8 -weight_decay = 1e-2 - -[scheduler] -scheduler_type = "ConstantLR" -update_interval = "1:step" - -[dropout] -dropout_probability = 0 -use_gyro_dropout = false - -[dataset] -hf_repo = "acme/images" -revision = "main" -horizontal_flip = true -random_crop = true -resize_image_max_size = 512 - -[checkpointing] -# save_folder = "/path/to/ckpts" -save_interval = "250:step" - -[test_diffusion] -num_inference_steps = 30 -use_short_prompts = false -prompts = [ - "", - # "green grass, " -] diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 8af1ddf..c484b6a 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -50,9 +50,10 @@ def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValu class TrainingConfig(BaseModel): + device: str = "cpu" + dtype: str = "float32" duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION} seed: int = 0 - gpu_index: int = 0 batch_size: int = 1 gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP} clip_grad_norm: float | None = None diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 021a168..8344a99 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -6,8 +6,9 @@ from pathlib import Path from typing import Any, Callable, Generic, Iterable, TypeVar, cast import numpy as np +import torch from loguru import logger -from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack +from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack from torch.autograd import backward from torch.nn import Parameter from torch.optim import Optimizer @@ -298,10 +299,17 @@ class Trainer(Generic[ConfigType, Batch], ABC): @cached_property def device(self) -> Device: - selected_device = Device(device=f"cuda:{self.config.training.gpu_index}") + selected_device = Device(self.config.training.device) logger.info(f"Using device: {selected_device}") return selected_device + @cached_property + def dtype(self) -> DType: + dtype = getattr(torch, self.config.training.dtype, None) + assert isinstance(dtype, DType), f"Unknown dtype: {self.config.training.dtype}" + logger.info(f"Using dtype: {dtype}") + return dtype + @property def parameters(self) -> list[Parameter]: """Returns a list of all parameters in all models""" diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index 20c2f3a..d96a820 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -4,7 +4,8 @@ train = true [training] duration = "100:epoch" seed = 0 -gpu_index = 0 +device = "cpu" +dtype = "float32" batch_size = 4 gradient_accumulation = "4:step" clip_grad_norm = 1.0 diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 35bc8e4..ae88eef 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path from typing import cast -from warnings import warn import pytest import torch @@ -76,12 +75,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): @pytest.fixture -def mock_config(test_device: torch.device) -> MockConfig: - if not test_device.type == "cuda": - warn("only running on CUDA, skipping") - pytest.skip("Skipping test because test_device is not CUDA") +def mock_config() -> MockConfig: config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml") - config.training.gpu_index = test_device.index return config