diff --git a/docs/guides/training_101/index.md b/docs/guides/training_101/index.md
index 15aa422..0159df8 100644
--- a/docs/guides/training_101/index.md
+++ b/docs/guides/training_101/index.md
@@ -220,13 +220,14 @@ Example:
 ```python
 from refiners.training_utils import BaseConfig, TrainingConfig, OptimizerConfig, LRSchedulerConfig, Optimizers, LRSchedulers
+from refiners.training_utils.common import TimeUnit, TimeValue
 
 class AutoencoderConfig(BaseConfig):
     # Since we are using a synthetic dataset, we will use a arbitrary fixed epoch size.
     epoch_size: int = 2048
 
 
 training = TrainingConfig(
-    duration="1000:epoch",
+    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32"
@@ -335,11 +336,11 @@ We can also evaluate the model using the `compute_evaluation` method.
 ```python
 training = TrainingConfig(
-    duration="1000:epoch",
+    duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
     batch_size=32,
     device="cuda" if torch.cuda.is_available() else "cpu",
     dtype="float32",
-    evaluation_interval="50:epoch" # We set the evaluation to be done every 10 epochs
+    evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
 )
 
 
 class AutoencoderTrainer(Trainer[AutoencoderConfig, Batch]):
@@ -459,6 +460,8 @@ You can train this toy model using the code below:
     import torch
     from loguru import logger
     from PIL import Image
+    from torch.nn import functional as F
+
     from refiners.fluxion import layers as fl
     from refiners.fluxion.utils import image_to_tensor, tensor_to_image
     from refiners.training_utils import (
@@ -476,7 +479,7 @@ You can train this toy model using the code below:
         register_callback,
         register_model,
     )
-    from torch.nn import functional as F
+    from refiners.training_utils.common import TimeUnit, TimeValue
 
 
     class ConvBlock(fl.Chain):
@@ -487,7 +490,7 @@ You can train this toy model using the code below:
                     out_channels=out_channels,
                     kernel_size=3,
                     padding=1,
-                    groups=min(in_channels, out_channels)
+                    groups=min(in_channels, out_channels),
                 ),
                 fl.LayerNorm2d(out_channels),
                 fl.SiLU(),
@@ -576,9 +579,7 @@ You can train this toy model using the code below:
         random.seed(seed)
 
         while True:
-            rectangle = Image.new(
-                "L", (random.randint(1, size), random.randint(1, size)), color=255
-            )
+            rectangle = Image.new("L", (random.randint(1, size), random.randint(1, size)), color=255)
             mask = Image.new("L", (size, size))
             mask.paste(
                 rectangle,
@@ -627,11 +628,11 @@ You can train this toy model using the code below:
     )
 
     training = TrainingConfig(
-        duration="1000:epoch",  # type: ignore
+        duration=TimeValue(number=1000, unit=TimeUnit.EPOCH),
         batch_size=32,
         device="cuda" if torch.cuda.is_available() else "cpu",
         dtype="float32",
-        evaluation_interval="50:epoch",  # type: ignore
+        evaluation_interval=TimeValue(number=50, unit=TimeUnit.EPOCH),
     )
 
     optimizer = OptimizerConfig(
@@ -639,9 +640,7 @@ You can train this toy model using the code below:
         type=Optimizers.AdamW,
         learning_rate=1e-4,
     )
-    lr_scheduler = LRSchedulerConfig(
-        type=LRSchedulerType.CONSTANT_LR
-    )
+    lr_scheduler = LRSchedulerConfig(type=LRSchedulerType.CONSTANT_LR)
 
     config = AutoencoderConfig(
         training=training,
@@ -672,9 +671,7 @@ You can train this toy model using the code below:
             return Autoencoder()
 
         def compute_loss(self, batch: Batch) -> torch.Tensor:
-            x_reconstructed = self.autoencoder.decoder(
-                self.autoencoder.encoder(batch.image)
-            )
+            x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(batch.image))
             return F.binary_cross_entropy(x_reconstructed, batch.image)
 
         def compute_evaluation(self) -> None:
@@ -687,14 +684,14 @@ You can train this toy model using the code below:
                 x_reconstructed = self.autoencoder.decoder(self.autoencoder.encoder(mask))
                 loss = F.mse_loss(x_reconstructed, mask)
                 validation_losses.append(loss.detach().cpu().item())
-                grid.append((tensor_to_image(mask), tensor_to_image((x_reconstructed>0.5).float())))
+                grid.append((tensor_to_image(mask), tensor_to_image((x_reconstructed > 0.5).float())))
 
             mean_loss = sum(validation_losses) / len(validation_losses)
             logger.info(f"Mean validation loss: {mean_loss}, epoch: {self.clock.epoch}")
 
             import matplotlib.pyplot as plt
 
-            _, axes = plt.subplots(4, 2, figsize=(8, 16)) # type: ignore
+            _, axes = plt.subplots(4, 2, figsize=(8, 16))  # type: ignore
 
             for i, (mask, reconstructed) in enumerate(grid):
                 axes[i, 0].imshow(mask, cmap="gray")
diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py
index 49d70a4..ad49d2d 100644
--- a/src/refiners/training_utils/clock.py
+++ b/src/refiners/training_utils/clock.py
@@ -48,11 +48,11 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
 
     @cached_property
     def unit_to_steps(self) -> dict[TimeUnit, int]:
-        iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
+        iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation.unit == TimeUnit.EPOCH else 1
         return {
             TimeUnit.STEP: 1,
             TimeUnit.EPOCH: self.num_batches_per_epoch,
-            TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
+            TimeUnit.ITERATION: self.gradient_accumulation.number * iteration_factor,
         }
 
     def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
@@ -62,7 +62,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         return steps // self.unit_to_steps[unit]
 
     def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
-        number, unit = time_value["number"], time_value["unit"]
+        number, unit = time_value.number, time_value.unit
         steps = self.convert_time_unit_to_steps(number=number, unit=unit)
         return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
 
@@ -81,13 +81,13 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     @cached_property
     def num_step_per_iteration(self) -> int:
         return self.convert_time_unit_to_steps(
-            number=self.gradient_accumulation["number"], unit=self.gradient_accumulation["unit"]
+            number=self.gradient_accumulation.number, unit=self.gradient_accumulation.unit
         )
 
     @cached_property
     def num_step_per_evaluation(self) -> int:
         return self.convert_time_unit_to_steps(
-            number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
+            number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
         )
 
     def reset(self) -> None:
@@ -113,13 +113,13 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     @cached_property
     def evaluation_interval_steps(self) -> int:
         return self.convert_time_unit_to_steps(
-            number=self.evaluation_interval["number"], unit=self.evaluation_interval["unit"]
+            number=self.evaluation_interval.number, unit=self.evaluation_interval.unit
         )
 
     @cached_property
     def lr_scheduler_interval_steps(self) -> int:
         return self.convert_time_unit_to_steps(
-            number=self.lr_scheduler_interval["number"], unit=self.lr_scheduler_interval["unit"]
+            number=self.lr_scheduler_interval.number, unit=self.lr_scheduler_interval.unit
         )
 
     @property
diff --git a/src/refiners/training_utils/common.py b/src/refiners/training_utils/common.py
index b6fb131..46b070c 100644
--- a/src/refiners/training_utils/common.py
+++ b/src/refiners/training_utils/common.py
@@ -1,4 +1,5 @@
 import random
+from dataclasses import dataclass
 from enum import Enum
 from functools import wraps
 from typing import Any, Callable, Iterable
@@ -7,7 +8,6 @@
 import numpy as np
 import torch
 from loguru import logger
 from torch import Tensor, cuda, nn
-from typing_extensions import TypedDict
 
 from refiners.fluxion.utils import manual_seed
@@ -79,26 +79,32 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
     return decorator
 
 
-class TimeUnit(Enum):
+class TimeUnit(str, Enum):
     STEP = "step"
     EPOCH = "epoch"
     ITERATION = "iteration"
     DEFAULT = "step"
 
 
-class TimeValue(TypedDict):
+@dataclass
+class TimeValue:
     number: int
     unit: TimeUnit
 
 
-def parse_number_unit_field(value: str | int | dict[str, str | int]) -> TimeValue:
+TimeValueInput = str | int | dict[str, str | int] | TimeValue
+
+
+def parse_number_unit_field(value: TimeValueInput) -> TimeValue:
     match value:
         case str(value_str):
             number, unit = value_str.split(sep=":")
-            return {"number": int(number.strip()), "unit": TimeUnit(value=unit.strip().lower())}
+            return TimeValue(number=int(number.strip()), unit=TimeUnit(value=unit.strip().lower()))
         case int(number):
-            return {"number": number, "unit": TimeUnit.DEFAULT}
+            return TimeValue(number=number, unit=TimeUnit.DEFAULT)
         case {"number": int(number), "unit": str(unit)}:
-            return {"number": number, "unit": TimeUnit(value=unit.lower())}
+            return TimeValue(number=number, unit=TimeUnit(value=unit.lower()))
+        case TimeValue(number, unit):
+            return TimeValue(number=number, unit=unit)
         case _:
             raise ValueError(f"Unsupported value format: {value}")
diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 9c953a3..254d7e1 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -23,11 +23,11 @@ ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
+    duration: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
     seed: int = 0
     batch_size: int = 1
-    gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
-    evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
+    gradient_accumulation: TimeValue = TimeValue(number=1, unit=TimeUnit.STEP)
+    evaluation_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
     evaluation_seed: int = 0
 
     model_config = ConfigDict(extra="forbid")
@@ -63,8 +63,8 @@ class LRSchedulerType(str, Enum):
 
 class LRSchedulerConfig(BaseModel):
     type: LRSchedulerType = LRSchedulerType.DEFAULT
-    update_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
-    warmup: TimeValue = {"number": 0, "unit": TimeUnit.ITERATION}
+    update_interval: TimeValue = TimeValue(number=1, unit=TimeUnit.ITERATION)
+    warmup: TimeValue = TimeValue(number=0, unit=TimeUnit.ITERATION)
     gamma: float = 0.1
     lr_lambda: Callable[[int], float] | None = None
     mode: Literal["min", "max"] = "min"
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 3875a9b..6905c31 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -241,7 +241,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @cached_property
     def lr_scheduler(self) -> LRScheduler:
         config = self.config.lr_scheduler
-        scheduler_step_size = config.update_interval["number"]
+        scheduler_step_size = config.update_interval.number
 
         match config.type:
             case LRSchedulerType.CONSTANT_LR:
@@ -286,7 +286,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             case _:
                 raise ValueError(f"Unknown scheduler type: {config.type}")
 
-        warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval["unit"])
+        warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval.unit)
         if warmup_scheduler_steps > 0:
             lr_scheduler = WarmupScheduler(
                 optimizer=self.optimizer,
diff --git a/tests/training_utils/test_common.py b/tests/training_utils/test_common.py
new file mode 100644
index 0000000..db036f6
--- /dev/null
+++ b/tests/training_utils/test_common.py
@@ -0,0 +1,33 @@
+import pytest
+
+from refiners.training_utils.common import TimeUnit, TimeValue, TimeValueInput, parse_number_unit_field
+
+
+@pytest.mark.parametrize(
+    "input_value, expected_output",
+    [
+        ("10: step", TimeValue(number=10, unit=TimeUnit.STEP)),
+        ("20 :epoch", TimeValue(number=20, unit=TimeUnit.EPOCH)),
+        ("30: Iteration", TimeValue(number=30, unit=TimeUnit.ITERATION)),
+        (50, TimeValue(number=50, unit=TimeUnit.DEFAULT)),
+        ({"number": 100, "unit": "STEP"}, TimeValue(number=100, unit=TimeUnit.STEP)),
+        (TimeValue(number=200, unit=TimeUnit.EPOCH), TimeValue(number=200, unit=TimeUnit.EPOCH)),
+    ],
+)
+def test_parse_number_unit_field(input_value: TimeValueInput, expected_output: TimeValue):
+    result = parse_number_unit_field(input_value)
+    assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "invalid_input",
+    [
+        "invalid:input",
+        {"number": "not_a_number", "unit": "step"},
+        {"invalid_key": 10},
+        None,
+    ],
+)
+def test_parse_number_unit_field_invalid_input(invalid_input: TimeValueInput):
+    with pytest.raises(ValueError):
+        parse_number_unit_field(invalid_input)
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 617e506..19fa462 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -10,7 +10,7 @@ from torch.optim import SGD
 
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
-from refiners.training_utils.common import TimeUnit, count_learnable_parameters, human_readable_number
+from refiners.training_utils.common import TimeUnit, TimeValue, count_learnable_parameters, human_readable_number
 from refiners.training_utils.config import BaseConfig, ModelConfig
 from refiners.training_utils.trainer import (
     Trainer,
@@ -96,7 +96,7 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
 @pytest.fixture
 def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
     mock_config_short = mock_config.model_copy(deep=True)
-    mock_config_short.training.duration = {"number": 3, "unit": TimeUnit.STEP}
+    mock_config_short.training.duration = TimeValue(number=3, unit=TimeUnit.STEP)
     return MockTrainer(config=mock_config_short)
 
 
@@ -130,10 +130,10 @@ def training_clock() -> TrainingClock:
     return TrainingClock(
         dataset_length=100,
         batch_size=10,
-        training_duration={"number": 5, "unit": TimeUnit.EPOCH},
-        gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH},
-        evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH},
-        lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH},
+        training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
+        gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
+        evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
+        lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
     )
 
 
@@ -183,7 +183,7 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
     clock = mock_trainer.clock
     config = mock_trainer.config
 
-    assert clock.num_step_per_iteration == config.training.gradient_accumulation["number"]
+    assert clock.num_step_per_iteration == config.training.gradient_accumulation.number
     assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
 
     assert mock_trainer.step_counter == 0
@@ -191,8 +191,8 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
 
     mock_trainer.train()
 
-    assert clock.epoch == config.training.duration["number"]
-    assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch
+    assert clock.epoch == config.training.duration.number
+    assert clock.step == config.training.duration.number * clock.num_batches_per_epoch
 
     assert mock_trainer.step_counter == mock_trainer.clock.step
 
@@ -206,7 +206,7 @@ def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
 
     mock_trainer_short.train()
 
-    assert clock.step == config.training.duration["number"]
+    assert clock.step == config.training.duration.number
 
 
 @pytest.fixture