change TimeValue to a dataclass

This commit is contained in:
limiteinductive 2024-03-19 13:20:40 +00:00 committed by Benjamin Trom
parent b8fae60d38
commit 6a72943ff7
7 changed files with 85 additions and 49 deletions

View file

@ -220,13 +220,14 @@ Example:
```python ```python
from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
from refiners.training_utils.common import TimeUnit, TimeValue
class AutoencoderConfig(BaseConfig): class AutoencoderConfig(BaseConfig):
# Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size. # Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
epoch_size: int = 2048 epoch_size: int = 2048
training = TrainingConfig( training = TrainingConfig(
duration="1000:epoch", duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
batch_size=32, batch_size=32,
device="cuda" if torch.cuda.is_available() else "cpu", device="cuda" if torch.cuda.is_available() else "cpu",
dtype="float32" dtype="float32"
@ -335,11 +336,11 @@ We can also evaluate the model using the `compute_evaluation` method.
```python ```python
training = TrainingConfig( training = TrainingConfig(
duration="1000:epoch", duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
batch_size=32, batch_size=32,
device="cuda" if torch.cuda.is_available() else "cpu", device="cuda" if torch.cuda.is_available() else "cpu",
dtype="float32", dtype="float32",
evaluation_interval="50:epoch" # We set the evaluation to be done every 10 epochs evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
) )
class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]): class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
@ -459,6 +460,8 @@ You can train this toy model using the code below:
import torch import torch
from loguru import logger from loguru import logger
from PIL import Image from PIL import Image
from torch.nn import functional as F
from refiners.fluxion import layers as fl from refiners.fluxion import layers as fl
from refiners.fluxion.utils import image_to_tensor, tensor_to_image from refiners.fluxion.utils import image_to_tensor, tensor_to_image
from refiners.training_utils import ( from refiners.training_utils import (
@ -476,7 +479,7 @@ You can train this toy model using the code below:
register_callback, register_callback,
register_model, register_model,
) )
from torch.nn import functional as F from refiners.training_utils.common import TimeUnit, TimeValue
class ConvBlock(fl.Chain): class ConvBlock(fl.Chain):
@ -487,7 +490,7 @@ You can train this toy model using the code below:
out_channels=out_channels, out_channels=out_channels,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
groups=min(in_channels, out_channels) groups=min(in_channels, out_channels),
), ),
fl.LayerNorm2d(out_channels), fl.LayerNorm2d(out_channels),
fl.SiLU(), fl.SiLU(),
@ -576,9 +579,7 @@ You can train this toy model using the code below:
random.seed(seed) random.seed(seed)
while True: while True:
rectangle = Image.new( rectangle = Image.new("L", (random.randint(1, size), random.randint(1, size)), color=255)
"L", (random.randint(1, size), random.randint(1, size)), color=255
)
mask = Image.new("L", (size, size)) mask = Image.new("L", (size, size))
mask.paste( mask.paste(
rectangle, rectangle,
@ -627,11 +628,11 @@ You can train this toy model using the code below:
) )
training = TrainingConfig( training = TrainingConfig(
duration="1000:epoch", # type: ignore duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
batch_size=32, batch_size=32,
device="cuda" if torch.cuda.is_available() else "cpu", device="cuda" if torch.cuda.is_available() else "cpu",
dtype="float32", dtype="float32",
evaluation_interval="50:epoch", # type: ignore evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
) )
optimizer = OptimizerConfig( optimizer = OptimizerConfig(
@ -639,9 +640,7 @@ You can train this toy model using the code below:
learning_rate=1e-4, learning_rate=1e-4,
) )
lr_scheduler = LRSchedulerConfig( lr_scheduler = LRSchedulerConfig(type=LRSchedulerType.CONSTANT_LR)
type=LRSchedulerType.CONSTANT_LR
)
config = AutoencoderConfig( config = AutoencoderConfig(
training=training, training=training,
@ -672,9 +671,7 @@ You can train this toy model using the code below:
return Autoencoder() return Autoencoder()
def compute_loss(self, batch: Batch) -> torch.Tensor: def compute_loss(self, batch: Batch) -> torch.Tensor:
x_reconstructed = self.autoencoder.decoder( x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
self.autoencoder.encoder(batch.image)
)
return F.binary_cross_entropy(x_reconstructed, batch.image) return F.binary_cross_entropy(x_reconstructed, batch.image)
def compute_evaluation(self) -> None: def compute_evaluation(self) -> None:
@ -687,14 +684,14 @@ You can train this toy model using the code below:
x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(mask)) x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(mask))
loss = F.mse_loss(x_reconstructed, mask) loss = F.mse_loss(x_reconstructed, mask)
validation_losses.append(loss.detach().cpu().item()) validation_losses.append(loss.detach().cpu().item())
grid.append((tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float()))) grid.append((tensor_to_image(mask), tensor_to_image((x_reconstructed > 0.5).float())))
mean_loss = sum(validation_losses) / len(validation_losses) mean_loss = sum(validation_losses) / len(validation_losses)
logger.info(f"Mean validation loss: {mean_loss}, epoch: {self.clock.epoch}") logger.info(f"Mean validation loss: {mean_loss}, epoch: {self.clock.epoch}")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
_, axes = plt.subplots(4, 2, figsize=(8, 16)) # type: ignore _, axes = plt.subplots(4, 2, figsize=(8, 16)) # type: ignore
for i, (mask, reconstructed) in enumerate(grid): for i, (mask, reconstructed) in enumerate(grid):
axes[i, 0].imshow(mask, cmap="gray") axes[i, 0].imshow(mask, cmap="gray")

View file

@ -48,11 +48,11 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
@cached_property @cached_property
def unit_to_steps(self) -> dict[TimeUnit, int]: def unit_to_steps(self) -> dict[TimeUnit, int]:
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1 iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation.unit == TimeUnit.EPOCH else 1
return { return {
TimeUnit.STEP: 1, TimeUnit.STEP: 1,
TimeUnit.EPOCH: self.num_batches_per_epoch, TimeUnit.EPOCH: self.num_batches_per_epoch,
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor, TimeUnit.ITERATION: self.gradient_accumulation.number * iteration_factor,
} }
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int: def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
@ -62,7 +62,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
return steps // self.unit_to_steps[unit] return steps // self.unit_to_steps[unit]
def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int: def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
number, unit = time_value["number"], time_value["unit"] number, unit = time_value.number, time_value.unit
steps = self.convert_time_unit_to_steps(number=number, unit=unit) steps = self.convert_time_unit_to_steps(number=number, unit=unit)
return self.convert_steps_to_time_unit(steps=steps, unit=target_unit) return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
@ -81,13 +81,13 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
@cached_property @cached_property
def num_step_per_iteration(self) -> int: def num_step_per_iteration(self) -> int:
return self.convert_time_unit_to_steps( return self.convert_time_unit_to_steps(
number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"] number=self.gradient_accumulation.number, unit=self.gradient_accumulation.unit
) )
@cached_property @cached_property
def num_step_per_evaluation(self) -> int: def num_step_per_evaluation(self) -> int:
return self.convert_time_unit_to_steps( return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"] number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
) )
def reset(self) -> None: def reset(self) -> None:
@ -113,13 +113,13 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
@cached_property @cached_property
def evaluation_interval_steps(self) -> int: def evaluation_interval_steps(self) -> int:
return self.convert_time_unit_to_steps( return self.convert_time_unit_to_steps(
number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"] number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
) )
@cached_property @cached_property
def lr_scheduler_interval_steps(self) -> int: def lr_scheduler_interval_steps(self) -> int:
return self.convert_time_unit_to_steps( return self.convert_time_unit_to_steps(
number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"] number=self.lr_scheduler_interval.number, unit=self.lr_scheduler_interval.unit
) )
@property @property

View file

@ -1,4 +1,5 @@
import random import random
from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Any, Callable, Iterable from typing import Any, Callable, Iterable
@ -7,7 +8,6 @@ import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from torch import Tensor, cuda, nn from torch import Tensor, cuda, nn
from typing_extensions import TypedDict
from refiners.fluxion.utils import manual_seed from refiners.fluxion.utils import manual_seed
@ -79,26 +79,32 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
return decorator return decorator
class TimeUnit(Enum): class TimeUnit(str, Enum):
STEP = "step" STEP = "step"
EPOCH = "epoch" EPOCH = "epoch"
ITERATION = "iteration" ITERATION = "iteration"
DEFAULT = "step" DEFAULT = "step"
class TimeValue(TypedDict): @dataclass
class TimeValue:
number: int number: int
unit: TimeUnit unit: TimeUnit
def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue: TimeValueInput = str | int | dict[str, str | int] | TimeValue
def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
match value: match value:
case str(value_str): case str(value_str):
number, unit = value_str.split(sep=":") number, unit = value_str.split(sep=":")
return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())} return TimeValue(number=int(number.strip()), unit=TimeUnit(value=unit.strip().lower()))
case int(number): case int(number):
return {"number": number, "unit": TimeUnit.DEFAULT} return TimeValue(number=number, unit=TimeUnit.DEFAULT)
case {"number": int(number), "unit": str(unit)}: case {"number": int(number), "unit": str(unit)}:
return {"number": number, "unit": TimeUnit(value=unit.lower())} return TimeValue(number=number, unit=TimeUnit(value=unit.lower()))
case TimeValue(number, unit):
return TimeValue(number=number, unit=unit)
case _: case _:
raise ValueError(f"Unsupported value format: {value}") raise ValueError(f"Unsupported value format: {value}")

View file

@ -23,11 +23,11 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
class TrainingConfig(BaseModel): class TrainingConfig(BaseModel):
device: str = "cpu" device: str = "cpu"
dtype: str = "float32" dtype: str = "float32"
duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION} duration: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
seed: int = 0 seed: int = 0
batch_size: int = 1 batch_size: int = 1
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP} gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION} evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
evaluation_seed: int = 0 evaluation_seed: int = 0
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
@ -63,8 +63,8 @@ class LRSchedulerType(str, Enum):
class LRSchedulerConfig(BaseModel): class LRSchedulerConfig(BaseModel):
type: LRSchedulerType = LRSchedulerType.DEFAULT type: LRSchedulerType = LRSchedulerType.DEFAULT
update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION} update_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION} warmup: TimeValue = TimeValue(number=0, unit=TimeUnit.ITERATION)
gamma: float = 0.1 gamma: float = 0.1
lr_lambda: Callable[[int], float] | None = None lr_lambda: Callable[[int], float] | None = None
mode: Literal["min", "max"] = "min" mode: Literal["min", "max"] = "min"

View file

@ -241,7 +241,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@cached_property @cached_property
def lr_scheduler(self) -> LRScheduler: def lr_scheduler(self) -> LRScheduler:
config = self.config.lr_scheduler config = self.config.lr_scheduler
scheduler_step_size = config.update_interval["number"] scheduler_step_size = config.update_interval.number
match config.type: match config.type:
case LRSchedulerType.CONSTANT_LR: case LRSchedulerType.CONSTANT_LR:
@ -286,7 +286,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
case _: case _:
raise ValueError(f"Unknown scheduler type: {config.type}") raise ValueError(f"Unknown scheduler type: {config.type}")
warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"]) warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval.unit)
if warmup_scheduler_steps > 0: if warmup_scheduler_steps > 0:
lr_scheduler = WarmupScheduler( lr_scheduler = WarmupScheduler(
optimizer=self.optimizer, optimizer=self.optimizer,

View file

@ -0,0 +1,33 @@
import pytest
from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
@pytest.mark.parametrize(
"input_value, expected_output",
[
("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
(50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
(TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
],
)
def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
result = parse_number_unit_field(input_value)
assert result == expected_output
@pytest.mark.parametrize(
"invalid_input",
[
"invalid:input",
{"number": "not_a_number", "unit": "step"},
{"invalid_key": 10},
None,
],
)
def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
with pytest.raises(ValueError):
parse_number_unit_field(invalid_input)

View file

@ -10,7 +10,7 @@ from torch.optim import SGD
from refiners.fluxion import layers as fl from refiners.fluxion import layers as fl
from refiners.fluxion.utils import norm from refiners.fluxion.utils import norm
from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number from refiners.training_utils.common import TimeUnit, TimeValue, count_learnable_parameters, human_readable_number
from refiners.training_utils.config import BaseConfig, ModelConfig from refiners.training_utils.config import BaseConfig, ModelConfig
from refiners.training_utils.trainer import ( from refiners.training_utils.trainer import (
Trainer, Trainer,
@ -96,7 +96,7 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
@pytest.fixture @pytest.fixture
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer: def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
mock_config_short = mock_config.model_copy(deep=True) mock_config_short = mock_config.model_copy(deep=True)
mock_config_short.training.duration = {"number": 3, "unit": TimeUnit.STEP} mock_config_short.training.duration = TimeValue(number=3, unit=TimeUnit.STEP)
return MockTrainer(config=mock_config_short) return MockTrainer(config=mock_config_short)
@ -130,10 +130,10 @@ def training_clock() -> TrainingClock:
return TrainingClock( return TrainingClock(
dataset_length=100, dataset_length=100,
batch_size=10, batch_size=10,
training_duration={"number": 5, "unit": TimeUnit.EPOCH}, training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH}, gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH}, evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH}, lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
) )
@ -183,7 +183,7 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
clock = mock_trainer.clock clock = mock_trainer.clock
config = mock_trainer.config config = mock_trainer.config
assert clock.num_step_per_iteration == config.training.gradient_accumulation["number"] assert clock.num_step_per_iteration == config.training.gradient_accumulation.number
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
assert mock_trainer.step_counter == 0 assert mock_trainer.step_counter == 0
@ -191,8 +191,8 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
mock_trainer.train() mock_trainer.train()
assert clock.epoch == config.training.duration["number"] assert clock.epoch == config.training.duration.number
assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch assert clock.step == config.training.duration.number * clock.num_batches_per_epoch
assert mock_trainer.step_counter == mock_trainer.clock.step assert mock_trainer.step_counter == mock_trainer.clock.step
@ -206,7 +206,7 @@ def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
mock_trainer_short.train() mock_trainer_short.train()
assert clock.step == config.training.duration["number"] assert clock.step == config.training.duration.number
@pytest.fixture @pytest.fixture