make device and dtype work in Trainer class

limiteinductive 2024-02-06 21:39:38 +00:00 committed by Benjamin Trom
parent 1fa5266f56
commit ea05f3d327
6 changed files with 15 additions and 124 deletions

View file

@@ -1,53 +0,0 @@
[wandb]
mode = "offline"
entity = "acme"
project = "test-ldm-training"

[models]
lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false}
text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true}
unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true}

[latent_diffusion]
unconditional_sampling_probability = 0.2
offset_noise = 0.1

[training]
duration = "1:epoch"
seed = 0
gpu_index = 0
num_epochs = 1
batch_size = 1
gradient_accumulation = "1:step"
clip_grad_norm = 2.0
clip_grad_value = 1.0
evaluation_interval = "1:epoch"
evaluation_seed = 0

[optimizer]
optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam"
learning_rate = 1e-5
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2

[scheduler]

[dropout]
dropout_probability = 0.2

[dataset]
hf_repo = "acme/images"
revision = "main"

[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "1:epoch"

[test_diffusion]
prompts = [
    "A cute cat",
]

View file

@@ -1,61 +0,0 @@
[wandb]
mode = "offline" # "online", "offline", "disabled"
entity = "acme"
project = "test-textual-inversion"

[models]
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}

[latent_diffusion]
unconditional_sampling_probability = 0.05
offset_noise = 0.1

[textual_inversion]
placeholder_token = "<cat-toy>"
initializer_token = "toy"
# style_mode = true

[training]
duration = "2000:step"
seed = 0
gpu_index = 0
batch_size = 4
gradient_accumulation = "1:step"
evaluation_interval = "250:step"
evaluation_seed = 1

[optimizer]
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
learning_rate = 5e-4
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2

[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"

[dropout]
dropout_probability = 0
use_gyro_dropout = false

[dataset]
hf_repo = "acme/images"
revision = "main"
horizontal_flip = true
random_crop = true
resize_image_max_size = 512

[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "250:step"

[test_diffusion]
num_inference_steps = 30
use_short_prompts = false
prompts = [
    "<cat-toy>",
    # "green grass, <cat-toy>"
]

View file

@@ -50,9 +50,10 @@ def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
 class TrainingConfig(BaseModel):
+    device: str = "cpu"
+    dtype: str = "float32"
     duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     seed: int = 0
-    gpu_index: int = 0
     batch_size: int = 1
     gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
     clip_grad_norm: float | None = None
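The two new fields are plain strings with defaults, replacing the gpu_index integer. A minimal sketch of how they behave in isolation (illustrative only, not code from this commit; the class name here is invented):

# Illustrative sketch of the new fields in isolation (hypothetical class name).
from pydantic import BaseModel

class TrainingDeviceConfig(BaseModel):
    device: str = "cpu"      # any torch device string, e.g. "cuda:0", "mps"
    dtype: str = "float32"   # name of a torch dtype attribute, e.g. "float16", "bfloat16"

print(TrainingDeviceConfig())                                   # device='cpu' dtype='float32'
print(TrainingDeviceConfig(device="cuda:0", dtype="bfloat16"))  # explicit override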

View file

@@ -6,8 +6,9 @@ from pathlib import Path
 from typing import Any, Callable, Generic, Iterable, TypeVar, cast

 import numpy as np
 import torch
 from loguru import logger
-from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
+from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
 from torch.autograd import backward
 from torch.nn import Parameter
 from torch.optim import Optimizer
@@ -298,10 +299,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @cached_property
     def device(self) -> Device:
-        selected_device = Device(device=f"cuda:{self.config.training.gpu_index}")
+        selected_device = Device(self.config.training.device)
         logger.info(f"Using device: {selected_device}")
         return selected_device

+    @cached_property
+    def dtype(self) -> DType:
+        dtype = getattr(torch, self.config.training.dtype, None)
+        assert isinstance(dtype, DType), f"Unknown dtype: {self.config.training.dtype}"
+        logger.info(f"Using dtype: {dtype}")
+        return dtype
+
     @property
     def parameters(self) -> list[Parameter]:
         """Returns a list of all parameters in all models"""

View file

@@ -4,7 +4,8 @@ train = true
 [training]
 duration = "100:epoch"
 seed = 0
-gpu_index = 0
+device = "cpu"
+dtype = "float32"
 batch_size = 4
 gradient_accumulation = "4:step"
 clip_grad_norm = 1.0

View file

@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
 from typing import cast
-from warnings import warn

 import pytest
 import torch
@@ -76,12 +75,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
 @pytest.fixture
-def mock_config(test_device: torch.device) -> MockConfig:
-    if not test_device.type == "cuda":
-        warn("only running on CUDA, skipping")
-        pytest.skip("Skipping test because test_device is not CUDA")
+def mock_config() -> MockConfig:
     config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
-    config.training.gpu_index = test_device.index
     return config
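With the CUDA skip and the gpu_index patch gone, the fixture is driven entirely by mock_config.toml, so the tests default to CPU. A hedged sketch of how a test could still opt into another device through the config (the test body is illustrative, not part of this commit):

# Illustrative only: overriding the configured device/dtype inside a test.
import torch

def test_device_and_dtype_resolution(mock_config) -> None:
    mock_config.training.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    mock_config.training.dtype = "float16"
    # A trainer built from this config would resolve Trainer.device and
    # Trainer.dtype from these strings via the cached properties above.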