mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 09:08:14 +00:00
make device and dtype work in Trainer class
This commit is contained in:
parent
1fa5266f56
commit
ea05f3d327
|
@ -1,53 +0,0 @@
|
||||||
[wandb]
|
|
||||||
mode = "offline"
|
|
||||||
entity = "acme"
|
|
||||||
project = "test-ldm-training"
|
|
||||||
|
|
||||||
[models]
|
|
||||||
lda = {checkpoint="/path/to/stable-diffusion-1-5/lda.safetensors", train=false}
|
|
||||||
text_encoder = {checkpoint="/path/to/stable-diffusion-1-5/text_encoder.safetensors", train=true}
|
|
||||||
unet = {checkpoint="/path/to/stable-diffusion-1-5/unet.safetensors", train=true}
|
|
||||||
|
|
||||||
[latent_diffusion]
|
|
||||||
unconditional_sampling_probability = 0.2
|
|
||||||
offset_noise = 0.1
|
|
||||||
|
|
||||||
[training]
|
|
||||||
duration = "1:epoch"
|
|
||||||
seed = 0
|
|
||||||
gpu_index = 0
|
|
||||||
num_epochs = 1
|
|
||||||
batch_size = 1
|
|
||||||
gradient_accumulation = "1:step"
|
|
||||||
clip_grad_norm = 2.0
|
|
||||||
clip_grad_value = 1.0
|
|
||||||
evaluation_interval = "1:epoch"
|
|
||||||
evaluation_seed = 0
|
|
||||||
|
|
||||||
|
|
||||||
[optimizer]
|
|
||||||
optimizer = "AdamW" # "AdamW", "AdamW8bit", "Lion8bit", "Prodigy", "SGD", "Adam"
|
|
||||||
learning_rate = 1e-5
|
|
||||||
betas = [0.9, 0.999]
|
|
||||||
eps = 1e-8
|
|
||||||
weight_decay = 1e-2
|
|
||||||
|
|
||||||
|
|
||||||
[scheduler]
|
|
||||||
|
|
||||||
|
|
||||||
[dropout]
|
|
||||||
dropout_probability = 0.2
|
|
||||||
|
|
||||||
[dataset]
|
|
||||||
hf_repo = "acme/images"
|
|
||||||
revision = "main"
|
|
||||||
|
|
||||||
[checkpointing]
|
|
||||||
# save_folder = "/path/to/ckpts"
|
|
||||||
save_interval = "1:epoch"
|
|
||||||
|
|
||||||
[test_diffusion]
|
|
||||||
prompts = [
|
|
||||||
"A cute cat",
|
|
||||||
]
|
|
|
@ -1,61 +0,0 @@
|
||||||
[wandb]
|
|
||||||
mode = "offline" # "online", "offline", "disabled"
|
|
||||||
entity = "acme"
|
|
||||||
project = "test-textual-inversion"
|
|
||||||
|
|
||||||
[models]
|
|
||||||
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
|
|
||||||
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
|
|
||||||
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
|
|
||||||
|
|
||||||
[latent_diffusion]
|
|
||||||
unconditional_sampling_probability = 0.05
|
|
||||||
offset_noise = 0.1
|
|
||||||
|
|
||||||
[textual_inversion]
|
|
||||||
placeholder_token = "<cat-toy>"
|
|
||||||
initializer_token = "toy"
|
|
||||||
# style_mode = true
|
|
||||||
|
|
||||||
[training]
|
|
||||||
duration = "2000:step"
|
|
||||||
seed = 0
|
|
||||||
gpu_index = 0
|
|
||||||
batch_size = 4
|
|
||||||
gradient_accumulation = "1:step"
|
|
||||||
evaluation_interval = "250:step"
|
|
||||||
evaluation_seed = 1
|
|
||||||
|
|
||||||
[optimizer]
|
|
||||||
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
|
|
||||||
learning_rate = 5e-4
|
|
||||||
betas = [0.9, 0.999]
|
|
||||||
eps = 1e-8
|
|
||||||
weight_decay = 1e-2
|
|
||||||
|
|
||||||
[scheduler]
|
|
||||||
scheduler_type = "ConstantLR"
|
|
||||||
update_interval = "1:step"
|
|
||||||
|
|
||||||
[dropout]
|
|
||||||
dropout_probability = 0
|
|
||||||
use_gyro_dropout = false
|
|
||||||
|
|
||||||
[dataset]
|
|
||||||
hf_repo = "acme/images"
|
|
||||||
revision = "main"
|
|
||||||
horizontal_flip = true
|
|
||||||
random_crop = true
|
|
||||||
resize_image_max_size = 512
|
|
||||||
|
|
||||||
[checkpointing]
|
|
||||||
# save_folder = "/path/to/ckpts"
|
|
||||||
save_interval = "250:step"
|
|
||||||
|
|
||||||
[test_diffusion]
|
|
||||||
num_inference_steps = 30
|
|
||||||
use_short_prompts = false
|
|
||||||
prompts = [
|
|
||||||
"<cat-toy>",
|
|
||||||
# "green grass, <cat-toy>"
|
|
||||||
]
|
|
|
@ -50,9 +50,10 @@ def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValu
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseModel):
|
class TrainingConfig(BaseModel):
|
||||||
|
device: str = "cpu"
|
||||||
|
dtype: str = "float32"
|
||||||
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
gpu_index: int = 0
|
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
|
||||||
clip_grad_norm: float | None = None
|
clip_grad_norm: float | None = None
|
||||||
|
|
|
@ -6,8 +6,9 @@ from pathlib import Path
|
||||||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
|
from torch import Tensor, cuda, device as Device, dtype as DType, get_rng_state, set_rng_state, stack
|
||||||
from torch.autograd import backward
|
from torch.autograd import backward
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
@ -298,10 +299,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def device(self) -> Device:
|
def device(self) -> Device:
|
||||||
selected_device = Device(device=f"cuda:{self.config.training.gpu_index}")
|
selected_device = Device(self.config.training.device)
|
||||||
logger.info(f"Using device: {selected_device}")
|
logger.info(f"Using device: {selected_device}")
|
||||||
return selected_device
|
return selected_device
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def dtype(self) -> DType:
|
||||||
|
dtype = getattr(torch, self.config.training.dtype, None)
|
||||||
|
assert isinstance(dtype, DType), f"Unknown dtype: {self.config.training.dtype}"
|
||||||
|
logger.info(f"Using dtype: {dtype}")
|
||||||
|
return dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> list[Parameter]:
|
def parameters(self) -> list[Parameter]:
|
||||||
"""Returns a list of all parameters in all models"""
|
"""Returns a list of all parameters in all models"""
|
||||||
|
|
|
@ -4,7 +4,8 @@ train = true
|
||||||
[training]
|
[training]
|
||||||
duration = "100:epoch"
|
duration = "100:epoch"
|
||||||
seed = 0
|
seed = 0
|
||||||
gpu_index = 0
|
device = "cpu"
|
||||||
|
dtype = "float32"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
gradient_accumulation = "4:step"
|
gradient_accumulation = "4:step"
|
||||||
clip_grad_norm = 1.0
|
clip_grad_norm = 1.0
|
||||||
|
|
|
@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -76,12 +75,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config(test_device: torch.device) -> MockConfig:
|
def mock_config() -> MockConfig:
|
||||||
if not test_device.type == "cuda":
|
|
||||||
warn("only running on CUDA, skipping")
|
|
||||||
pytest.skip("Skipping test because test_device is not CUDA")
|
|
||||||
config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
|
config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml")
|
||||||
config.training.gpu_index = test_device.index
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue