From 446796da5701edcfcc8d6dcba0cbe5b2fe42b83d Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Thu, 18 Apr 2024 14:41:57 +0000
Subject: [PATCH] Refactor TimeValue

---
 docs/guides/training_101/index.md             | 20 +++---
 src/refiners/training_utils/__init__.py       |  6 ++
 src/refiners/training_utils/clock.py          | 39 +++++------
 src/refiners/training_utils/common.py         | 66 ++++++++++++++-----
 src/refiners/training_utils/config.py         | 14 ++--
 tests/training_utils/mock_config.toml         |  4 +-
 .../training_utils/mock_config_2_models.toml  |  2 +-
 tests/training_utils/test_common.py           | 40 ++++++++---
 tests/training_utils/test_trainer.py          | 48 ++++++++------
 9 files changed, 148 insertions(+), 91 deletions(-)

diff --git a/docs/guides/training_101/index.md b/docs/guides/training_101/index.md
index 0159df8..8d1be89 100644
--- a/docs/guides/training_101/index.md
+++ b/docs/guides/training_101/index.md
@@ -219,15 +219,14 @@ We will now define the configuration for the autoencoder. It holds the configura
 Example:
 
 ```python
-from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
-from refiners.training_utils.common import TimeUnit, TimeValue
+from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers, Epoch
 
 class AutoencoderConfig(BaseConfig):
     # Since we are using a synthetic dataset, we will use an arbitrary fixed epoch size.
     epoch_size: int = 2048
 
 training = TrainingConfig(
-    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
+    duration=Epoch(1000),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32"
@@ -336,11 +335,11 @@ We can also evaluate the model using the `compute_evaluation` method.
 
 ```python
 training = TrainingConfig(
-    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
+    duration=Epoch(1000),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32",
-    evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
+    evaluation_interval=Epoch(50),
 )
 
 class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
@@ -478,9 +477,8 @@ You can train this toy model using the code below:
         TrainingConfig,
         register_callback,
         register_model,
+        Epoch,
     )
-    from refiners.training_utils.common import TimeUnit, TimeValue
-
 
     class ConvBlock(fl.Chain):
         def __init__(self, in_channels: int, out_channels: int) -> None:
@@ -628,11 +626,11 @@ You can train this toy model using the code below:
     )
 
     training = TrainingConfig(
-        duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
+        duration=Epoch(1000),
         batch_size=32,
        device="cuda" if torch.cuda.is_available() else "cpu",
         dtype="float32",
-        evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
+        evaluation_interval=Epoch(50),
     )
 
     optimizer = OptimizerConfig(
@@ -702,9 +700,9 @@ You can train this toy model using the code below:
             axes[i, 1].axis("off")
             axes[i, 1].set_title("Reconstructed")
 
-        plt.tight_layout()
+        plt.tight_layout()  # type: ignore
         plt.savefig(f"result_{trainer.clock.epoch}.png")  # type: ignore
-        plt.close()
+        plt.close()  # type: ignore
 
     @register_callback()
     def logging(self, config: CallbackConfig) -> LoggingCallback:
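The documentation hunks above replace the old `TimeValue(number=..., unit=...)` constructor with the new unit classes. A minimal sketch of the resulting API (the values are the guide's own; the standalone construction is only illustrative):

```python
from refiners.training_utils import Epoch, TrainingConfig

# Old spelling, removed by this patch:
#   duration=TimeValue(number=1000, unit=TimeUnit.EPOCH)
# New spelling: the unit is the class, the number is its only field.
training = TrainingConfig(
    duration=Epoch(1000),
    batch_size=32,
    evaluation_interval=Epoch(50),
)
```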
diff --git a/src/refiners/training_utils/__init__.py b/src/refiners/training_utils/__init__.py
index 236a40c..f39b762 100644
--- a/src/refiners/training_utils/__init__.py
+++ b/src/refiners/training_utils/__init__.py
@@ -31,6 +31,7 @@ for dep in refiners_requires:
         sys.exit(1)
 
 from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.clock import ClockConfig
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
 from refiners.training_utils.config import (
     BaseConfig,
     LRSchedulerConfig,
@@ -59,4 +60,9 @@ __all__ = [
     "ClockConfig",
     "Optimizers",
     "LRSchedulerType",
+    "TimeValue",
+    "TimeUnit",
+    "Epoch",
+    "Iteration",
+    "Step",
 ]
diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py
index 07ab952..ba0f3f3 100644
--- a/src/refiners/training_utils/clock.py
+++ b/src/refiners/training_utils/clock.py
@@ -3,7 +3,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from refiners.training_utils.callback import Callback, CallbackConfig
-from refiners.training_utils.common import TimeUnit, TimeValue
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue
 
 if TYPE_CHECKING:
     from refiners.training_utils.config import BaseConfig
@@ -52,47 +52,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
 
     @cached_property
     def unit_to_steps(self) -> dict[TimeUnit, int]:
-        iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation.unit == TimeUnit.EPOCH else 1
+        iteration_factor = self.num_batches_per_epoch if isinstance(self.gradient_accumulation, Epoch) else 1
         return {
-            TimeUnit.STEP: 1,
-            TimeUnit.EPOCH: self.num_batches_per_epoch,
-            TimeUnit.ITERATION: self.gradient_accumulation.number * iteration_factor,
+            Step: 1,
+            Epoch: self.num_batches_per_epoch,
+            Iteration: self.gradient_accumulation.number * iteration_factor,
         }
 
-    def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
-        return number * self.unit_to_steps[unit]
+    def convert_time_value_to_steps(self, time_value: TimeValue) -> int:
+        return time_value.number * self.unit_to_steps[time_value.unit]
 
     def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
         return steps // self.unit_to_steps[unit]
 
     def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
-        number, unit = time_value.number, time_value.unit
-        steps = self.convert_time_unit_to_steps(number=number, unit=unit)
+        steps = self.convert_time_value_to_steps(time_value=time_value)
         return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
 
     @cached_property
     def num_epochs(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.EPOCH)
+        return self.convert_time_value(time_value=self.training_duration, target_unit=Epoch)
 
     @cached_property
     def num_iterations(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.ITERATION)
+        return self.convert_time_value(time_value=self.training_duration, target_unit=Iteration)
 
     @cached_property
     def num_steps(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=TimeUnit.STEP)
+        return self.convert_time_value(time_value=self.training_duration, target_unit=Step)
 
     @cached_property
     def num_step_per_iteration(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.gradient_accumulation.number, unit=self.gradient_accumulation.unit
-        )
+        return self.convert_time_value_to_steps(self.gradient_accumulation)
 
     @cached_property
     def num_step_per_evaluation(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
-        )
+        return self.convert_time_value_to_steps(self.evaluation_interval)
 
     def reset(self) -> None:
         self.start_time = None
@@ -116,15 +111,11 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
 
     @cached_property
     def evaluation_interval_steps(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
-        )
+        return self.convert_time_value_to_steps(self.evaluation_interval)
 
     @cached_property
     def lr_scheduler_interval_steps(self) -> int:
-        return self.convert_time_unit_to_steps(
-            number=self.lr_scheduler_interval.number, unit=self.lr_scheduler_interval.unit
-        )
+        return self.convert_time_value_to_steps(self.lr_scheduler_interval)
 
     @property
     def is_optimizer_step(self) -> bool:
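A usage sketch of the refactored conversion methods, mirroring the test fixture later in this patch (`dataset_length=100`, `batch_size=10`, so one epoch is 10 steps, and `gradient_accumulation=Epoch(1)` makes one iteration span a whole epoch):

```python
from refiners.training_utils.clock import TrainingClock
from refiners.training_utils.common import Epoch, Iteration, Step

clock = TrainingClock(
    dataset_length=100,
    batch_size=10,
    training_duration=Epoch(5),
    gradient_accumulation=Epoch(1),
    evaluation_interval=Epoch(1),
    lr_scheduler_interval=Epoch(1),
)

# A single TimeValue argument replaces the old (number, unit) pair:
assert clock.convert_time_value_to_steps(Epoch(2)) == 20
assert clock.convert_time_value_to_steps(Step(7)) == 7
assert clock.convert_time_value_to_steps(Iteration(1)) == 10
```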
diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py
index 1fe9e1d..2e193fa 100644
--- a/src/refiners/training_utils/common.py
+++ b/src/refiners/training_utils/common.py
@@ -1,7 +1,6 @@
 import random
 from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, Protocol, runtime_checkable
 
 import numpy as np
 import torch
@@ -83,32 +82,67 @@ class scoped_seed:
             cuda.set_rng_state(self.cuda_torch_state)
 
 
-class TimeUnit(str, Enum):
-    STEP = "step"
-    EPOCH = "epoch"
-    ITERATION = "iteration"
-    DEFAULT = "step"
+@dataclass
+@runtime_checkable
+class TimeValue(Protocol):
+    number: int
+
+    @property
+    def unit(self) -> "TimeUnit":
+        match self.__class__.__name__:
+            case "Step":
+                return Step
+            case "Epoch":
+                return Epoch
+            case "Iteration":
+                return Iteration
+            case _:
+                raise ValueError(f"Unsupported time unit: {self.__class__.__name__}")
+
+    @classmethod
+    def from_str(cls, value: str) -> "TimeValue":
+        match cls.extract_number_unit(value):
+            case number, "step":
+                return Step(number)
+            case number, "epoch":
+                return Epoch(number)
+            case number, "iteration":
+                return Iteration(number)
+            case _:
+                raise ValueError(f"Incorrect time value format: {value}")
+
+    @staticmethod
+    def extract_number_unit(value: str) -> tuple[int, str]:
+        number, unit = value.lower().split(":")
+        return int(number.strip()), unit.strip()
 
 
 @dataclass
-class TimeValue:
+class Step(TimeValue):
     number: int
-    unit: TimeUnit
 
 
+@dataclass
+class Epoch(TimeValue):
+    number: int
+
+
+@dataclass
+class Iteration(TimeValue):
+    number: int
+
+
+TimeUnit = type[Step] | type[Epoch] | type[Iteration]
 TimeValueInput = str | int | dict[str, str | int] | TimeValue
 
 
 def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
     match value:
         case str(value_str):
-            number, unit = value_str.split(sep=":")
-            return TimeValue(number=int(number.strip()), unit=TimeUnit(value=unit.strip().lower()))
+            return TimeValue.from_str(value_str)
         case int(number):
-            return TimeValue(number=number, unit=TimeUnit.DEFAULT)
-        case {"number": int(number), "unit": str(unit)}:
-            return TimeValue(number=number, unit=TimeUnit(value=unit.lower()))
-        case TimeValue(number, unit):
-            return TimeValue(number=number, unit=unit)
+            return Step(number=number)
+        case TimeValue(number):
+            return value
         case _:
             raise ValueError(f"Unsupported value format: {value}")
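The Protocol-plus-dataclasses design above keeps all parsing in one place. A short sketch of the resulting behavior, drawn from the test cases added in this patch:

```python
from refiners.training_utils.common import Epoch, Step, TimeValue, parse_number_unit_field

assert TimeValue.from_str("10: step") == Step(10)      # whitespace and case are tolerated
assert TimeValue.from_str("20 :epoch") == Epoch(20)
assert Step(10).unit is Step                           # a value's unit is simply its class
assert parse_number_unit_field(50) == Step(50)         # bare ints still default to steps
assert parse_number_unit_field(Epoch(3)) == Epoch(3)   # TimeValue instances pass through
```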
diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index dfe536b..f1b654e 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -11,7 +11,7 @@ from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
 
 from refiners.training_utils.clock import ClockConfig
-from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
 
 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -22,18 +22,18 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    duration: TimeValue = Iteration(1)
     seed: int = 0
     batch_size: int = 1
-    gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
-    evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    gradient_accumulation: Step | Epoch = Step(1)
+    evaluation_interval: Iteration | Epoch = Iteration(1)
     gradient_clipping_max_norm: float | None = None
     evaluation_seed: int = 0
 
     model_config = ConfigDict(extra="forbid")
 
     @field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
-    def parse_field(cls, value: Any) -> TimeValue:
+    def parse_field(cls, value: TimeValueInput) -> TimeValue:
         return parse_number_unit_field(value)
 
 
@@ -63,8 +63,8 @@ class LRSchedulerType(str, Enum):
 
 class LRSchedulerConfig(BaseModel):
     type: LRSchedulerType = LRSchedulerType.DEFAULT
-    update_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
-    warmup: TimeValue = TimeValue(number=0, unit=TimeUnit.ITERATION)
+    update_interval: Iteration | Epoch = Iteration(1)
+    warmup: TimeValue = Iteration(0)
     gamma: float = 0.1
     lr_lambda: Callable[[int], float] | None = None
     mode: Literal["min", "max"] = "min"
diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index eebb211..0b48702 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -22,5 +22,5 @@ learning_rate = 1
 
 [lr_scheduler]
 type = "ConstantLR"
-update_interval = "1:step"
-warmup = "20:step"
+update_interval = "1:iteration"
+warmup = "20:iteration"
diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml
index 302c70b..361890a 100644
--- a/tests/training_utils/mock_config_2_models.toml
+++ b/tests/training_utils/mock_config_2_models.toml
@@ -23,4 +23,4 @@ learning_rate = 1
 
 [lr_scheduler]
 type = "ConstantLR"
-update_interval = "1:step"
+update_interval = "1:iteration"
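Because the field validators run with `mode="before"`, the `"<number>:<unit>"` strings used in the TOML fixtures above still round-trip through `parse_number_unit_field`. A minimal sketch (the specific field values are illustrative, not from the patch):

```python
from refiners.training_utils.common import Epoch, Step
from refiners.training_utils.config import TrainingConfig

# Strings, bare ints, and TimeValue instances are all accepted as inputs;
# the validator normalizes them before pydantic checks the annotated types.
cfg = TrainingConfig(duration="100:epoch", gradient_accumulation="4:step")
assert cfg.duration == Epoch(100)
assert cfg.gradient_accumulation == Step(4)
```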
diff --git a/tests/training_utils/test_common.py b/tests/training_utils/test_common.py
index f490424..cf76b54 100644
--- a/tests/training_utils/test_common.py
+++ b/tests/training_utils/test_common.py
@@ -3,21 +3,41 @@ import random
 import pytest
 import torch
 
-from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field, scoped_seed
+from refiners.training_utils.common import (
+    Epoch,
+    Iteration,
+    Step,
+    TimeValue,
+    TimeValueInput,
+    parse_number_unit_field,
+    scoped_seed,
+)
 
 
 @pytest.mark.parametrize(
     "input_value, expected_output",
     [
-        ("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
-        ("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
-        ("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
-        (50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
-        ({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
-        (TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
+        ("3 : steP", Step(3)),
+        ("5: epoch", Epoch(5)),
+        (" 7:Iteration", Iteration(7)),
     ],
 )
-def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
+def test_time_value_from_str(input_value: str, expected_output: TimeValue) -> None:
+    result = TimeValue.from_str(input_value)
+    assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "input_value, expected_output",
+    [
+        ("10: step", Step(10)),
+        ("20 :epoch", Epoch(20)),
+        ("30: Iteration", Iteration(30)),
+        (50, Step(50)),
+        (Epoch(200), Epoch(200)),
+    ],
+)
+def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue) -> None:
     result = parse_number_unit_field(input_value)
     assert result == expected_output
 
@@ -26,8 +46,8 @@ def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: T
     "invalid_input",
    [
         "invalid:input",
-        {"number": "not_a_number", "unit": "step"},
-        {"invalid_key": 10},
+        "10: invalid",
+        "10",
         None,
     ],
 )
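All of the invalid inputs above funnel into the same `ValueError`; a compact equivalent of what the parametrized test exercises:

```python
import pytest

from refiners.training_utils.common import parse_number_unit_field

# "invalid:input" fails int(); "10: invalid" has no matching unit;
# "10" has no ":" separator; None matches no case at all.
for bad_input in ("invalid:input", "10: invalid", "10", None):
    with pytest.raises(ValueError):
        parse_number_unit_field(bad_input)  # type: ignore[arg-type]
```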
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 940445d..3c7f45c 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -10,7 +10,13 @@ from torch.optim import SGD
 
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
-from refiners.training_utils.common import TimeUnit, TimeValue, count_learnable_parameters, human_readable_number
+from refiners.training_utils.common import (
+    Epoch,
+    Iteration,
+    Step,
+    count_learnable_parameters,
+    human_readable_number,
+)
 from refiners.training_utils.config import BaseConfig, ModelConfig
 from refiners.training_utils.trainer import (
     Trainer,
@@ -96,7 +102,7 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
 
 @pytest.fixture
 def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
     mock_config_short = mock_config.model_copy(deep=True)
-    mock_config_short.training.duration = TimeValue(number=3, unit=TimeUnit.STEP)
+    mock_config_short.training.duration = Step(3)
     return MockTrainer(config=mock_config_short)
 
 
@@ -130,10 +136,10 @@ def training_clock() -> TrainingClock:
     return TrainingClock(
         dataset_length=100,
         batch_size=10,
-        training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
-        gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
-        lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+        training_duration=Epoch(5),
+        gradient_accumulation=Epoch(1),
+        evaluation_interval=Epoch(1),
+        lr_scheduler_interval=Epoch(1),
     )
 
 
@@ -142,10 +148,10 @@ def test_small_dataset_error():
         TrainingClock(
             dataset_length=3,
             batch_size=10,
-            training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
-            gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
-            evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
-            lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+            training_duration=Epoch(5),
+            gradient_accumulation=Epoch(1),
+            evaluation_interval=Epoch(1),
+            lr_scheduler_interval=Epoch(1),
         )
 
 
@@ -154,23 +160,25 @@ def test_zero_batch_size_error():
         TrainingClock(
             dataset_length=3,
             batch_size=0,
-            training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
-            gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
-            evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
-            lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+            training_duration=Epoch(5),
+            gradient_accumulation=Epoch(1),
+            evaluation_interval=Epoch(1),
+            lr_scheduler_interval=Epoch(1),
         )
 
 
 def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
-    assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
-    assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20
-    assert training_clock.convert_time_unit_to_steps(1, TimeUnit.STEP) == 1
+    assert training_clock.convert_time_value_to_steps(Epoch(1)) == 10
+    assert training_clock.convert_time_value_to_steps(Epoch(2)) == 20
+    assert training_clock.convert_time_value_to_steps(Step(1)) == 1
+    assert training_clock.convert_time_value_to_steps(Iteration(1)) == 10
 
 
 def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
-    assert training_clock.convert_steps_to_time_unit(10, TimeUnit.EPOCH) == 1
-    assert training_clock.convert_steps_to_time_unit(20, TimeUnit.EPOCH) == 2
-    assert training_clock.convert_steps_to_time_unit(1, TimeUnit.STEP) == 1
+    assert training_clock.convert_steps_to_time_unit(10, Epoch) == 1
+    assert training_clock.convert_steps_to_time_unit(20, Epoch) == 2
+    assert training_clock.convert_steps_to_time_unit(1, Step) == 1
+    assert training_clock.convert_steps_to_time_unit(10, Iteration) == 1
 
 
 def test_clock_properties(training_clock: TrainingClock) -> None:
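For reference, a sketch of the fixture arithmetic behind these assertions (derived from the clock implementation above, not asserted verbatim in the patch): 100 samples at batch size 10 give 10 steps per epoch, so `Epoch(5)` of training is 50 steps, and `gradient_accumulation=Epoch(1)` makes an iteration 10 steps, hence 5 iterations.

```python
from refiners.training_utils.clock import TrainingClock
from refiners.training_utils.common import Epoch

clock = TrainingClock(
    dataset_length=100,
    batch_size=10,
    training_duration=Epoch(5),
    gradient_accumulation=Epoch(1),
    evaluation_interval=Epoch(1),
    lr_scheduler_interval=Epoch(1),
)
assert clock.num_batches_per_epoch == 10  # 100 // 10
assert clock.num_steps == 50              # 5 epochs * 10 steps
assert clock.num_epochs == 5
assert clock.num_iterations == 5          # 50 steps // 10 steps per iteration
```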