make device and dtype work in Trainer class

limiteinductive 2024-02-06 21:39:38 +00:00 committed by Benjamin Trom
parent 1fa5266f56
commit ea05f3d327
6 changed files with 15 additions and 124 deletions


@@ -1,53 +0,0 @@
[wandb]
mode = "offline"
entity = "acme"
project = "test-ldm-training"
[models]
lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false}
text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true}
unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true}
[latent_diffusion]
unconditional_sampling_probability = 0.2
offset_noise = 0.1
[training]
duration = "1:epoch"
seed = 0
gpu_index = 0
num_epochs = 1
batch_size = 1
gradient_accumulation = "1:step"
clip_grad_norm = 2.0
clip_grad_value = 1.0
evaluation_interval = "1:epoch"
evaluation_seed = 0
[optimizer]
optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam"
learning_rate = 1e-5
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2
[scheduler]
[dropout]
dropout_probability = 0.2
[dataset]
hf_repo = "acme/images"
revision = "main"
[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "1:epoch"
[test_diffusion]
prompts = [
"A cute cat",
]


@@ -1,61 +0,0 @@
[wandb]
mode = "offline" # "online", "offline", "disabled"
entity = "acme"
project = "test-textual-inversion"
[models]
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
[latent_diffusion]
unconditional_sampling_probability = 0.05
offset_noise = 0.1
[textual_inversion]
placeholder_token = "<cat-toy>"
initializer_token = "toy"
# style_mode = true
[training]
duration = "2000:step"
seed = 0
gpu_index = 0
batch_size = 4
gradient_accumulation = "1:step"
evaluation_interval = "250:step"
evaluation_seed = 1
[optimizer]
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
learning_rate = 5e-4
betas = [0.9, 0.999]
eps = 1e-8
weight_decay = 1e-2
[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"
[dropout]
dropout_probability = 0
use_gyro_dropout = false
[dataset]
hf_repo = "acme/images"
revision = "main"
horizontal_flip = true
random_crop = true
resize_image_max_size = 512
[checkpointing]
# save_folder = "/path/to/ckpts"
save_interval = "250:step"
[test_diffusion]
num_inference_steps = 30
use_short_prompts = false
prompts = [
"<cat-toy>",
# "green grass, <cat-toy>"
]


@@ -50,9 +50,10 @@ def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
 class TrainingConfig(BaseModel):
+    device: str = "cpu"
+    dtype: str = "float32"
     duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     seed: int = 0
-    gpu_index: int = 0
     batch_size: int = 1
     gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
     clip_grad_norm: float | None = None
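
Note (not part of the diff): the two new fields are plain strings. A minimal sketch of how they resolve to torch objects, assuming BaseModel here is pydantic's; the class name and script are illustrative only, while the conversion calls are the ones the Trainer uses:

    import torch
    from pydantic import BaseModel

    class TrainingConfigSketch(BaseModel):
        # mirrors the two fields added above
        device: str = "cpu"     # any string torch.device accepts, e.g. "cuda:1"
        dtype: str = "float32"  # an attribute name on the torch module, e.g. "bfloat16"

    config = TrainingConfigSketch(device="cuda:0", dtype="bfloat16")
    device = torch.device(config.device)        # same conversion as Trainer.device below
    dtype = getattr(torch, config.dtype, None)  # resolves "bfloat16" -> torch.bfloat16
    assert isinstance(dtype, torch.dtype), f"Unknown dtype: {config.dtype}"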


@@ -6,8 +6,9 @@ from pathlib import Path
 from typing import Any, Callable, Generic, Iterable, TypeVar, cast

 import numpy as np
+import torch
 from loguru import logger
-from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
+from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
 from torch.autograd import backward
 from torch.nn import Parameter
 from torch.optim import Optimizer
@@ -298,10 +299,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @cached_property
     def device(self) -> Device:
-        selected_device = Device(device=f"cuda:{self.config.training.gpu_index}")
+        selected_device = Device(self.config.training.device)
         logger.info(f"Using device: {selected_device}")
         return selected_device

+    @cached_property
+    def dtype(self) -> DType:
+        dtype = getattr(torch, self.config.training.dtype, None)
+        assert isinstance(dtype, DType), f"Unknown dtype: {self.config.training.dtype}"
+        logger.info(f"Using dtype: {dtype}")
+        return dtype
+
     @property
     def parameters(self) -> list[Parameter]:
         """Returns a list of all parameters in all models"""


@@ -4,7 +4,8 @@ train = true
 [training]
 duration = "100:epoch"
 seed = 0
-gpu_index = 0
+device = "cpu"
+dtype = "float32"
 batch_size = 4
 gradient_accumulation = "4:step"
 clip_grad_norm = 1.0


@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
 from typing import cast
-from warnings import warn

 import pytest
 import torch
@@ -76,12 +75,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
 @pytest.fixture
-def mock_config(test_device: torch.device) -> MockConfig:
-    if not test_device.type == "cuda":
-        warn("only running on CUDA, skipping")
-        pytest.skip("Skipping test because test_device is not CUDA")
+def mock_config() -> MockConfig:
     config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
-    config.training.gpu_index = test_device.index
     return config
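
Follow-up note (an assumption, not shown in this commit): with the CUDA-only guard removed, a test that does want a specific accelerator can simply override the new string fields after loading the config; the helper name below is hypothetical:

    def retarget_config(config, device: str = "cuda:0", dtype: str = "float16") -> None:
        # Both fields are plain strings, so no torch objects are needed here;
        # Trainer.device and Trainer.dtype perform the conversion later.
        config.training.device = device
        config.training.dtype = dtype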