mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
make device and dtype work in Trainer class
This commit is contained in:
parent
1fa5266f56
commit
ea05f3d327
|
@ -1,53 +0,0 @@
|
|||
[wandb]
|
||||
mode = "offline"
|
||||
entity = "acme"
|
||||
project = "test-ldm-training"
|
||||
|
||||
[models]
|
||||
lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false}
|
||||
text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true}
|
||||
unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true}
|
||||
|
||||
[latent_diffusion]
|
||||
unconditional_sampling_probability = 0.2
|
||||
offset_noise = 0.1
|
||||
|
||||
[training]
|
||||
duration = "1:epoch"
|
||||
seed = 0
|
||||
gpu_index = 0
|
||||
num_epochs = 1
|
||||
batch_size = 1
|
||||
gradient_accumulation = "1:step"
|
||||
clip_grad_norm = 2.0
|
||||
clip_grad_value = 1.0
|
||||
evaluation_interval = "1:epoch"
|
||||
evaluation_seed = 0
|
||||
|
||||
|
||||
[optimizer]
|
||||
optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam"
|
||||
learning_rate = 1e-5
|
||||
betas = [0.9, 0.999]
|
||||
eps = 1e-8
|
||||
weight_decay = 1e-2
|
||||
|
||||
|
||||
[scheduler]
|
||||
|
||||
|
||||
[dropout]
|
||||
dropout_probability = 0.2
|
||||
|
||||
[dataset]
|
||||
hf_repo = "acme/images"
|
||||
revision = "main"
|
||||
|
||||
[checkpointing]
|
||||
# save_folder = "/path/to/ckpts"
|
||||
save_interval = "1:epoch"
|
||||
|
||||
[test_diffusion]
|
||||
prompts = [
|
||||
"A cute cat",
|
||||
]
|
|
@ -1,61 +0,0 @@
|
|||
[wandb]
|
||||
mode = "offline" # "online", "offline", "disabled"
|
||||
entity = "acme"
|
||||
project = "test-textual-inversion"
|
||||
|
||||
[models]
|
||||
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
||||
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
||||
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
||||
|
||||
[latent_diffusion]
|
||||
unconditional_sampling_probability = 0.05
|
||||
offset_noise = 0.1
|
||||
|
||||
[textual_inversion]
|
||||
placeholder_token = "<cat-toy>"
|
||||
initializer_token = "toy"
|
||||
# style_mode = true
|
||||
|
||||
[training]
|
||||
duration = "2000:step"
|
||||
seed = 0
|
||||
gpu_index = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "1:step"
|
||||
evaluation_interval = "250:step"
|
||||
evaluation_seed = 1
|
||||
|
||||
[optimizer]
|
||||
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
||||
learning_rate = 5e-4
|
||||
betas = [0.9, 0.999]
|
||||
eps = 1e-8
|
||||
weight_decay = 1e-2
|
||||
|
||||
[scheduler]
|
||||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
|
||||
[dropout]
|
||||
dropout_probability = 0
|
||||
use_gyro_dropout = false
|
||||
|
||||
[dataset]
|
||||
hf_repo = "acme/images"
|
||||
revision = "main"
|
||||
horizontal_flip = true
|
||||
random_crop = true
|
||||
resize_image_max_size = 512
|
||||
|
||||
[checkpointing]
|
||||
# save_folder = "/path/to/ckpts"
|
||||
save_interval = "250:step"
|
||||
|
||||
[test_diffusion]
|
||||
num_inference_steps = 30
|
||||
use_short_prompts = false
|
||||
prompts = [
|
||||
"<cat-toy>",
|
||||
# "green grass, <cat-toy>"
|
||||
]
|
|
@ -50,9 +50,10 @@ def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValu
|
|||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
device: str = "cpu"
|
||||
dtype: str = "float32"
|
||||
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||
seed: int = 0
|
||||
gpu_index: int = 0
|
||||
batch_size: int = 1
|
||||
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
||||
clip_grad_norm: float | None = None
|
||||
|
|
|
@ -6,8 +6,9 @@ from pathlib import Path
|
|||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
|
||||
from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
|
||||
from torch.autograd import backward
|
||||
from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
@ -298,10 +299,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
|
||||
@cached_property
|
||||
def device(self) -> Device:
|
||||
selected_device = Device(device=f"cuda:{self.config.training.gpu_index}")
|
||||
selected_device = Device(self.config.training.device)
|
||||
logger.info(f"Using device: {selected_device}")
|
||||
return selected_device
|
||||
|
||||
@cached_property
|
||||
def dtype(self) -> DType:
|
||||
dtype = getattr(torch, self.config.training.dtype, None)
|
||||
assert isinstance(dtype, DType), f"Unknown dtype: {self.config.training.dtype}"
|
||||
logger.info(f"Using dtype: {dtype}")
|
||||
return dtype
|
||||
|
||||
@property
|
||||
def parameters(self) -> list[Parameter]:
|
||||
"""Returns a list of all parameters in all models"""
|
||||
|
|
|
@ -4,7 +4,8 @@ train = true
|
|||
[training]
|
||||
duration = "100:epoch"
|
||||
seed = 0
|
||||
gpu_index = 0
|
||||
device = "cpu"
|
||||
dtype = "float32"
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
clip_grad_norm = 1.0
|
||||
|
|
|
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
|||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -76,12 +75,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(test_device: torch.device) -> MockConfig:
|
||||
if not test_device.type == "cuda":
|
||||
warn("only running on CUDA, skipping")
|
||||
pytest.skip("Skipping test because test_device is not CUDA")
|
||||
def mock_config() -> MockConfig:
|
||||
config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
|
||||
config.training.gpu_index = test_device.index
|
||||
return config
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue