Refactor TimeValue

limiteinductive 2024-04-18 14:41:57 +00:00 committed by Benjamin Trom
parent 17246708b9
commit 446796da57
9 changed files with 148 additions and 91 deletions


@@ -219,15 +219,14 @@ We will now define the configuration for the autoencoder. It holds the configura
 Example:

 ```python
-from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
-from refiners.training_utils.common import TimeUnit, TimeValue
+from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers, Epoch

 class AutoencoderConfig(BaseConfig):
     # Since we are using a synthetic dataset, we will use an arbitrary fixed epoch size.
     epoch_size: int = 2048

 training = TrainingConfig(
-    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
+    duration=Epoch(1000),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32"
@@ -336,11 +335,11 @@ We can also evaluate the model using the `compute_evaluation` method.
 ```python
 training = TrainingConfig(
-    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
+    duration=Epoch(1000),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32",
-    evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
+    evaluation_interval=Epoch(50),
 )

 class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
@@ -478,9 +477,8 @@ You can train this toy model using the code below:
     TrainingConfig,
     register_callback,
     register_model,
+    Epoch,
 )
-from refiners.training_utils.common import TimeUnit, TimeValue

 class ConvBlock(fl.Chain):
     def __init__(self, in_channels: int, out_channels: int) -> None:
@@ -628,11 +626,11 @@ You can train this toy model using the code below:
 )

 training = TrainingConfig(
-    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
+    duration=Epoch(1000),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32",
-    evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
+    evaluation_interval=Epoch(50),
 )

 optimizer = OptimizerConfig(
@@ -702,9 +700,9 @@ You can train this toy model using the code below:
             axes[i, 1].axis("off")
             axes[i, 1].set_title("Reconstructed")

-        plt.tight_layout()
+        plt.tight_layout()  # type: ignore
         plt.savefig(f"result_{trainer.clock.epoch}.png")  # type: ignore
-        plt.close()
+        plt.close()  # type: ignore

     @register_callback()
     def logging(self, config: CallbackConfig) -> LoggingCallback:


@@ -31,6 +31,7 @@ for dep in refiners_requires:
     sys.exit(1)

 from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.clock import ClockConfig
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
 from refiners.training_utils.config import (
     BaseConfig,
     LRSchedulerConfig,
@@ -59,4 +60,9 @@ __all__ = [
     "ClockConfig",
     "Optimizers",
     "LRSchedulerType",
+    "TimeValue",
+    "TimeUnit",
+    "Epoch",
+    "Iteration",
+    "Step",
 ]
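
With these re-exports, the new time classes can be imported straight from the package root, which is what the updated documentation above already does. A minimal sketch, assuming the package is installed at this commit:

```python
# Both import paths now resolve to the same classes.
from refiners.training_utils import Epoch, Step, TimeValue
from refiners.training_utils.common import Epoch as CommonEpoch

assert Epoch is CommonEpoch
assert isinstance(Step(1), TimeValue)  # TimeValue is a runtime_checkable Protocol (see common.py below)
```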


@@ -3,7 +3,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from refiners.training_utils.callback import Callback, CallbackConfig
-from refiners.training_utils.common import TimeUnit, TimeValue
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue

 if TYPE_CHECKING:
     from refiners.training_utils.config import BaseConfig
@@ -52,47 +52,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     @cached_property
     def unit_to_steps(self) -> dict[TimeUnit, int]:
-        iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation.unit == TimeUnit.EPOCH else 1
+        iteration_factor = self.num_batches_per_epoch if isinstance(self.gradient_accumulation, Epoch) else 1
         return {
-            TimeUnit.STEP: 1,
-            TimeUnit.EPOCH: self.num_batches_per_epoch,
-            TimeUnit.ITERATION: self.gradient_accumulation.number * iteration_factor,
+            Step: 1,
+            Epoch: self.num_batches_per_epoch,
+            Iteration: self.gradient_accumulation.number * iteration_factor,
         }

-    def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
-        return number * self.unit_to_steps[unit]
+    def convert_time_value_to_steps(self, time_value: TimeValue) -> int:
+        return time_value.number * self.unit_to_steps[time_value.unit]

     def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
         return steps // self.unit_to_steps[unit]

     def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
-        number, unit = time_value.number, time_value.unit
-        steps = self.convert_time_unit_to_steps(number=number, unit=unit)
+        steps = self.convert_time_value_to_steps(time_value=time_value)
         return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)

     @cached_property
     def num_epochs(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
+        return self.convert_time_value(time_value=self.training_duration, target_unit=Epoch)

     @cached_property
     def num_iterations(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
+        return self.convert_time_value(time_value=self.training_duration, target_unit=Iteration)

     @cached_property
     def num_steps(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
+        return self.convert_time_value(time_value=self.training_duration, target_unit=Step)

     @cached_property
     def num_step_per_iteration(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.gradient_accumulation.number, unit=self.gradient_accumulation.unit
-        )
+        return self.convert_time_value_to_steps(self.gradient_accumulation)

     @cached_property
     def num_step_per_evaluation(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
-        )
+        return self.convert_time_value_to_steps(self.evaluation_interval)

     def reset(self) -> None:
         self.start_time = None
@@ -116,15 +111,11 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     @cached_property
     def evaluation_interval_steps(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
-        )
+        return self.convert_time_value_to_steps(self.evaluation_interval)

     @cached_property
     def lr_scheduler_interval_steps(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.lr_scheduler_interval.number, unit=self.lr_scheduler_interval.unit
-        )
+        return self.convert_time_value_to_steps(self.lr_scheduler_interval)

     @property
     def is_optimizer_step(self) -> bool:
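
To make the new conversion path concrete: with the values used by the test fixture further down (`dataset_length=100`, `batch_size=10`, so 10 batches per epoch, and `gradient_accumulation=Epoch(1)`), the unit table and conversions work out as in this hedged sketch, which mirrors the clock code above rather than exercising the real class:

```python
from refiners.training_utils import Epoch, Iteration, Step, TimeValue

num_batches_per_epoch = 100 // 10  # dataset_length // batch_size
unit_to_steps: dict[type, int] = {
    Step: 1,
    Epoch: num_batches_per_epoch,          # 10 steps per epoch
    Iteration: 1 * num_batches_per_epoch,  # gradient_accumulation = Epoch(1)
}

def convert_time_value_to_steps(time_value: TimeValue) -> int:
    # time_value.unit is now the class itself, so it keys the dict directly
    return time_value.number * unit_to_steps[time_value.unit]

assert convert_time_value_to_steps(Epoch(5)) == 50  # num_steps for a 5-epoch run
assert convert_time_value_to_steps(Iteration(1)) == 10
assert convert_time_value_to_steps(Step(7)) == 7
```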


@@ -1,7 +1,6 @@
 import random
 from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Protocol, runtime_checkable

 import numpy as np
 import torch
@@ -83,32 +82,67 @@ class scoped_seed:
             cuda.set_rng_state(self.cuda_torch_state)


-class TimeUnit(str, Enum):
-    STEP = "step"
-    EPOCH = "epoch"
-    ITERATION = "iteration"
-    DEFAULT = "step"
+@dataclass
+@runtime_checkable
+class TimeValue(Protocol):
+    number: int
+
+    @property
+    def unit(self) -> "TimeUnit":
+        match self.__class__.__name__:
+            case "Step":
+                return Step
+            case "Epoch":
+                return Epoch
+            case "Iteration":
+                return Iteration
+            case _:
+                raise ValueError(f"Unsupported time unit: {self.__class__.__name__}")
+
+    @classmethod
+    def from_str(cls, value: str) -> "TimeValue":
+        match cls.extract_number_unit(value):
+            case number, "step":
+                return Step(number)
+            case number, "epoch":
+                return Epoch(number)
+            case number, "iteration":
+                return Iteration(number)
+            case _:
+                raise ValueError(f"Incorrect time value format: {value}")
+
+    @staticmethod
+    def extract_number_unit(value: str) -> tuple[int, str]:
+        number, unit = value.lower().split(":")
+        return int(number.strip()), unit.strip()


 @dataclass
-class TimeValue:
+class Step(TimeValue):
     number: int
-    unit: TimeUnit
+
+
+@dataclass
+class Epoch(TimeValue):
+    number: int
+
+
+@dataclass
+class Iteration(TimeValue):
+    number: int
+
+
+TimeUnit = type[Step] | type[Epoch] | type[Iteration]

 TimeValueInput = str | int | dict[str, str | int] | TimeValue


 def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
     match value:
         case str(value_str):
-            number, unit = value_str.split(sep=":")
-            return TimeValue(number=int(number.strip()), unit=TimeUnit(value=unit.strip().lower()))
+            return TimeValue.from_str(value_str)
         case int(number):
-            return TimeValue(number=number, unit=TimeUnit.DEFAULT)
+            return Step(number=number)
-        case {"number": int(number), "unit": str(unit)}:
-            return TimeValue(number=number, unit=TimeUnit(value=unit.lower()))
-        case TimeValue(number, unit):
-            return TimeValue(number=number, unit=unit)
+        case TimeValue(number):
+            return value
         case _:
             raise ValueError(f"Unsupported value format: {value}")
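
The net effect of this rewrite: the string and int parsing behavior is preserved, but units are now types rather than enum members. A hedged round-trip of the new surface, assuming the module as defined above:

```python
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field

assert TimeValue.from_str("10:epoch") == Epoch(10)            # case-insensitive, whitespace-tolerant
assert Epoch(10).unit is Epoch                                # the unit is now the class itself
assert parse_number_unit_field(50) == Step(50)                # bare ints still default to steps
assert parse_number_unit_field(Iteration(3)) == Iteration(3)  # TimeValue instances pass through
```

Note that dict inputs such as `{"number": 100, "unit": "step"}` no longer match any case and now raise `ValueError`, which the updated tests below reflect (even though the `TimeValueInput` alias still mentions `dict`).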


@@ -11,7 +11,7 @@ from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer

 from refiners.training_utils.clock import ClockConfig
-from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field

 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -22,18 +22,18 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    duration: TimeValue = Iteration(1)
     seed: int = 0
     batch_size: int = 1
-    gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
-    evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    gradient_accumulation: Step | Epoch = Step(1)
+    evaluation_interval: Iteration | Epoch = Iteration(1)
     gradient_clipping_max_norm: float | None = None
     evaluation_seed: int = 0

     model_config = ConfigDict(extra="forbid")

     @field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
-    def parse_field(cls, value: Any) -> TimeValue:
+    def parse_field(cls, value: TimeValueInput) -> TimeValue:
         return parse_number_unit_field(value)
@@ -63,8 +63,8 @@ class LRSchedulerType(str, Enum):
 class LRSchedulerConfig(BaseModel):
     type: LRSchedulerType = LRSchedulerType.DEFAULT
-    update_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
-    warmup: TimeValue = TimeValue(number=0, unit=TimeUnit.ITERATION)
+    update_interval: Iteration | Epoch = Iteration(1)
+    warmup: TimeValue = Iteration(0)
     gamma: float = 0.1
     lr_lambda: Callable[[int], float] | None = None
     mode: Literal["min", "max"] = "min"
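
The stricter field types mean a misconfigured unit now fails at validation time instead of deep inside the clock. A hedged example, reusing the values from the documentation diff above (and assuming pydantic passes the validated dataclass instances through unchanged):

```python
from refiners.training_utils import Epoch, Step, TrainingConfig

config = TrainingConfig(
    duration=Epoch(1000),           # any TimeValue
    gradient_accumulation=Step(1),  # now restricted to Step | Epoch
    evaluation_interval=Epoch(50),  # now restricted to Iteration | Epoch
)

# The mode="before" validator still accepts the "number:unit" string
# shorthand used by the TOML configs below:
assert TrainingConfig(duration="100:epoch").duration == Epoch(100)
```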


@@ -22,5 +22,5 @@ learning_rate = 1

 [lr_scheduler]
 type = "ConstantLR"
-update_interval = "1:step"
-warmup = "20:step"
+update_interval = "1:iteration"
+warmup = "20:iteration"
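
The unit change here follows from the stricter types above: `update_interval` is now typed `Iteration | Epoch`, so `"1:step"` would parse to `Step(1)` and be rejected by validation; `warmup` remains a general `TimeValue`, so its switch to `"20:iteration"` is presumably for consistency. A quick hedged check of how the new strings parse:

```python
from refiners.training_utils.common import Iteration, TimeValue

assert TimeValue.from_str("1:iteration") == Iteration(1)
assert TimeValue.from_str("20:iteration") == Iteration(20)
```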


@@ -23,4 +23,4 @@ learning_rate = 1

 [lr_scheduler]
 type = "ConstantLR"
-update_interval = "1:step"
+update_interval = "1:iteration"


@@ -3,21 +3,41 @@ import random

 import pytest
 import torch

-from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
+from refiners.training_utils.common import (
+    Epoch,
+    Iteration,
+    Step,
+    TimeValue,
+    TimeValueInput,
+    parse_number_unit_field,
+    scoped_seed,
+)


 @pytest.mark.parametrize(
     "input_value, expected_output",
     [
-        ("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
-        ("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
-        ("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
-        (50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
-        ({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
-        (TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
+        ("3 : steP", Step(3)),
+        ("5: epoch", Epoch(5)),
+        (" 7:Iteration", Iteration(7)),
     ],
 )
-def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
+def test_time_value_from_str(input_value: str, expected_output: TimeValue) -> None:
+    result = TimeValue.from_str(input_value)
+    assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "input_value, expected_output",
+    [
+        ("10: step", Step(10)),
+        ("20 :epoch", Epoch(20)),
+        ("30: Iteration", Iteration(30)),
+        (50, Step(50)),
+        (Epoch(200), Epoch(200)),
+    ],
+)
+def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue) -> None:
     result = parse_number_unit_field(input_value)
     assert result == expected_output
@@ -26,8 +46,8 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T
     "invalid_input",
     [
         "invalid:input",
-        {"number": "not_a_number", "unit": "step"},
-        {"invalid_key": 10},
+        "10: invalid",
+        "10",
         None,
     ],
 )


@@ -10,7 +10,13 @@ from torch.optim import SGD

 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
-from refiners.training_utils.common import TimeUnit, TimeValue, count_learnable_parameters, human_readable_number
+from refiners.training_utils.common import (
+    Epoch,
+    Iteration,
+    Step,
+    count_learnable_parameters,
+    human_readable_number,
+)
 from refiners.training_utils.config import BaseConfig, ModelConfig
 from refiners.training_utils.trainer import (
     Trainer,
@@ -96,7 +102,7 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
 @pytest.fixture
 def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
     mock_config_short = mock_config.model_copy(deep=True)
-    mock_config_short.training.duration = TimeValue(number=3, unit=TimeUnit.STEP)
+    mock_config_short.training.duration = Step(3)
     return MockTrainer(config=mock_config_short)
@@ -130,10 +136,10 @@ def training_clock() -> TrainingClock:
     return TrainingClock(
         dataset_length=100,
         batch_size=10,
-        training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
-        gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+        training_duration=Epoch(5),
+        gradient_accumulation=Epoch(1),
+        evaluation_interval=Epoch(1),
+        lr_scheduler_interval=Epoch(1),
     )
@@ -142,10 +148,10 @@ def test_small_dataset_error():
     TrainingClock(
         dataset_length=3,
         batch_size=10,
-        training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
-        gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+        training_duration=Epoch(5),
+        gradient_accumulation=Epoch(1),
+        evaluation_interval=Epoch(1),
+        lr_scheduler_interval=Epoch(1),
     )
@@ -154,23 +160,25 @@ def test_zero_batch_size_error():
     TrainingClock(
         dataset_length=3,
         batch_size=0,
-        training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
-        gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+        training_duration=Epoch(5),
+        gradient_accumulation=Epoch(1),
+        evaluation_interval=Epoch(1),
+        lr_scheduler_interval=Epoch(1),
     )

 def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
-    assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
-    assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20
-    assert training_clock.convert_time_unit_to_steps(1, TimeUnit.STEP) == 1
+    assert training_clock.convert_time_value_to_steps(Epoch(1)) == 10
+    assert training_clock.convert_time_value_to_steps(Epoch(2)) == 20
+    assert training_clock.convert_time_value_to_steps(Step(1)) == 1
+    assert training_clock.convert_time_value_to_steps(Iteration(1)) == 10

 def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
-    assert training_clock.convert_steps_to_time_unit(10, TimeUnit.EPOCH) == 1
-    assert training_clock.convert_steps_to_time_unit(20, TimeUnit.EPOCH) == 2
-    assert training_clock.convert_steps_to_time_unit(1, TimeUnit.STEP) == 1
+    assert training_clock.convert_steps_to_time_unit(10, Epoch) == 1
+    assert training_clock.convert_steps_to_time_unit(20, Epoch) == 2
+    assert training_clock.convert_steps_to_time_unit(1, Step) == 1
+    assert training_clock.convert_steps_to_time_unit(10, Iteration) == 1

 def test_clock_properties(training_clock: TrainingClock) -> None: