diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index c572c45..91214ee 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -1,9 +1,8 @@ import time -from functools import cached_property from typing import TYPE_CHECKING, Any from refiners.training_utils.callback import Callback, CallbackConfig -from refiners.training_utils.common import Epoch, Iteration, Step, TimeUnit, TimeValue +from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue if TYPE_CHECKING: from refiners.training_utils.config import BaseConfig @@ -23,7 +22,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): self, batch_size: int, training_duration: TimeValue, - gradient_accumulation: int, + gradient_accumulation: Step, lr_scheduler_interval: TimeValue, verbose: bool = True, ) -> None: @@ -75,7 +74,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): @property def is_optimizer_step(self) -> bool: - return self.num_minibatches_processed == self.gradient_accumulation + return self.num_minibatches_processed == self.gradient_accumulation.number @property def done(self) -> bool: @@ -94,41 +93,42 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): logger.info(message) def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.reset() - trainer.clock.start_timer() + self.log(f"Starting training for {self.training_duration}.") + self.reset() + self.start_timer() def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.stop_timer() + self.stop_timer() self.log( ( "Training took: " - f"{trainer.clock.time_elapsed} seconds, " - f"{trainer.clock.iteration} iterations, " - f"{trainer.clock.epoch} epochs, " - f"{trainer.clock.step} steps." + f"{self.time_elapsed} seconds, " + f"{self.iteration} iterations, " + f"{self.epoch} epochs, " + f"{self.step} steps." ) ) def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.log(f"Epoch {trainer.clock.epoch} started.") + self.log(f"Epoch {self.epoch} started.") def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.log(f"Epoch {trainer.clock.epoch} ended.") - trainer.clock.epoch += 1 - trainer.clock.num_batches_processed = 0 + self.log(f"Epoch {self.epoch} ended.") + self.epoch += 1 + self.num_batches_processed = 0 def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: if self.num_minibatches_processed == 0: - self.log(f"Iteration {trainer.clock.iteration} started.") - self.log(f"Step {trainer.clock.step} started.") + self.log(f"Iteration {self.iteration} started.") + self.log(f"Step {self.step} started.") def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.log(f"Step {trainer.clock.step} ended.") - trainer.clock.step += 1 - trainer.clock.num_batches_processed += 1 - trainer.clock.num_minibatches_processed += 1 + self.log(f"Step {self.step} ended.") + self.step += 1 + self.num_batches_processed += 1 + self.num_minibatches_processed += 1 def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.log(f"Iteration {trainer.clock.iteration} ended.") - trainer.clock.iteration += 1 - trainer.clock.num_minibatches_processed = 0 + self.log(f"Iteration {self.iteration} ended.") + self.iteration += 1 + self.num_minibatches_processed = 0 diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index f2dfdca..a6220c5 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -11,7 +11,7 @@ from torch import Tensor from torch.optim import SGD, Adam, AdamW, Optimizer from refiners.training_utils.clock import ClockConfig -from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field +from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field # PyTorch optimizer parameters type # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced @@ -25,12 +25,12 @@ class TrainingConfig(BaseModel): duration: TimeValue = Iteration(1) # TimeValue(number=1, unit=TimeUnit.ITERATION) seed: int = 0 batch_size: int = 1 - gradient_accumulation: int = 1 + gradient_accumulation: Step = Step(1) gradient_clipping_max_norm: float | None = None model_config = ConfigDict(extra="forbid") - @field_validator("duration", mode="before") + @field_validator("duration", "gradient_accumulation", mode="before") def parse_field(cls, value: TimeValueInput) -> TimeValue: return parse_number_unit_field(value) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index ce1c768..c8bca70 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -282,7 +282,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): warmup_scheduler_steps = ( config.warmup.number if isinstance(config.warmup, Step) - else config.warmup.number * self.clock.gradient_accumulation + else config.warmup.number * self.clock.gradient_accumulation.number ) if warmup_scheduler_steps > 0: lr_scheduler = WarmupScheduler( @@ -350,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): def backward(self) -> None: """Backward pass on the loss.""" self._call_callbacks(event_name="on_backward_begin") - scaled_loss = self.loss / self.config.training.gradient_accumulation + scaled_loss = self.loss / self.config.training.gradient_accumulation.number backward(tensors=scaled_loss) self._call_callbacks(event_name="on_backward_end") if self.clock.is_optimizer_step: diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index 617c2b4..a147164 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -17,7 +17,7 @@ seed = 0 device = "cpu" dtype = "float32" batch_size = 4 -gradient_accumulation = 4 +gradient_accumulation = "4:step" gradient_clipping_max_norm = 1.0 [optimizer] diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml index 5e2471b..641bef2 100644 --- a/tests/training_utils/mock_config_2_models.toml +++ b/tests/training_utils/mock_config_2_models.toml @@ -12,7 +12,7 @@ verbose = false duration = "100:epoch" seed = 0 batch_size = 4 -gradient_accumulation = 4 +gradient_accumulation = "4:step" gradient_clipping_max_norm = 1.0 [optimizer] diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index ba3ca2c..c4cf612 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -206,27 +206,17 @@ def training_clock() -> TrainingClock: return TrainingClock( batch_size=10, training_duration=Epoch(5), - gradient_accumulation=1, + gradient_accumulation=Step(1), lr_scheduler_interval=Epoch(1), ) -def test_small_dataset_error(): - with pytest.raises(AssertionError): - TrainingClock( - batch_size=10, - training_duration=Epoch(5), - gradient_accumulation=1, - lr_scheduler_interval=Epoch(1), - ) - - def test_zero_batch_size_error(): with pytest.raises(AssertionError): TrainingClock( batch_size=0, training_duration=Epoch(5), - gradient_accumulation=1, + gradient_accumulation=Step(1), lr_scheduler_interval=Epoch(1), )