diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py
index 4fb111f..c572c45 100644
--- a/src/refiners/training_utils/clock.py
+++ b/src/refiners/training_utils/clock.py
@@ -21,24 +21,18 @@ class ClockConfig(CallbackConfig):
 class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
     def __init__(
         self,
-        dataset_length: int,
         batch_size: int,
         training_duration: TimeValue,
-        gradient_accumulation: TimeValue,
+        gradient_accumulation: int,
         lr_scheduler_interval: TimeValue,
         verbose: bool = True,
     ) -> None:
         assert batch_size > 0, "Batch size must be greater than 0."
-        assert (
-            dataset_length >= batch_size
-        ), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})."
-        self.dataset_length = dataset_length
         self.batch_size = batch_size
         self.training_duration = training_duration
         self.gradient_accumulation = gradient_accumulation
         self.lr_scheduler_interval = lr_scheduler_interval
         self.verbose = verbose
-        self.num_batches_per_epoch = dataset_length // batch_size
         self.start_time = None
         self.end_time = None
         self.step = 0
@@ -48,43 +42,16 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
         self.num_minibatches_processed = 0
         self.loss: Tensor | None = None

-    @cached_property
-    def unit_to_steps(self) -> dict[TimeUnit, int]:
-        iteration_factor = self.num_batches_per_epoch if isinstance(self.gradient_accumulation, Epoch) else 1
-        return {
-            Step: 1,
-            Epoch: self.num_batches_per_epoch,
-            Iteration: self.gradient_accumulation.number * iteration_factor,
-        }
-
-    def convert_time_value_to_steps(self, time_value: TimeValue) -> int:
-        return time_value.number * self.unit_to_steps[time_value.unit]
-
-    def convert_steps_to_time_unit(self, steps: int, unit: TimeUnit) -> int:
-        return steps // self.unit_to_steps[unit]
-
-    def convert_time_value(self, time_value: TimeValue, target_unit: TimeUnit) -> int:
-        steps = self.convert_time_value_to_steps(time_value=time_value)
-        return self.convert_steps_to_time_unit(steps=steps, unit=target_unit)
-
-    @cached_property
-    def num_epochs(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=Epoch)
-
-    @cached_property
-    def num_iterations(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=Iteration)
-
-    @cached_property
-    def num_steps(self) -> int:
-        return self.convert_time_value(time_value=self.training_duration, target_unit=Step)
-
-    @cached_property
-    def num_step_per_iteration(self) -> int:
-        return self.convert_time_value_to_steps(self.gradient_accumulation)
-
     def is_due(self, interval: TimeValue) -> bool:
-        return self.step % self.convert_time_value_to_steps(interval) == 0
+        match interval:
+            case Step(number):
+                return self.step % number == 0
+            case Iteration(number):
+                return self.iteration % number == 0
+            case Epoch(number):
+                return self.epoch % number == 0
+            case _:
+                raise ValueError(f"Unsupported TimeValue: {interval}")

     def reset(self) -> None:
         self.start_time = None
@@ -108,11 +75,19 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):

     @property
     def is_optimizer_step(self) -> bool:
-        return self.num_minibatches_processed == self.num_step_per_iteration
+        return self.num_minibatches_processed == self.gradient_accumulation

     @property
     def done(self) -> bool:
-        return self.step >= self.num_steps
+        match self.training_duration:
+            case Step(number):
+                return self.step >= number
+            case Iteration(number):
+                return self.iteration >= number
+            case Epoch(number):
+                return self.epoch >= number
+            case _:
+                raise ValueError(f"Unsupported TimeValue: {self.training_duration}")

     def log(self, message: str, /) -> None:
         if self.verbose:
@@ -120,14 +95,6 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):

     def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
         trainer.clock.reset()
-        self.log(
-            (
-                "Starting training for a total of: "
-                f"{trainer.clock.num_steps} steps, "
-                f"{trainer.clock.num_epochs} epochs, "
-                f"{trainer.clock.num_iterations} iterations."
-            )
-        )
         trainer.clock.start_timer()

     def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 3fc248f..f2dfdca 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -11,7 +11,7 @@ from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer

 from refiners.training_utils.clock import ClockConfig
-from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
+from refiners.training_utils.common import Epoch, Iteration, TimeValue, TimeValueInput, parse_number_unit_field

 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -25,12 +25,12 @@ class TrainingConfig(BaseModel):
     duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
     seed: int = 0
     batch_size: int = 1
-    gradient_accumulation: Step | Epoch = Step(1)
+    gradient_accumulation: int = 1
     gradient_clipping_max_norm: float | None = None

     model_config = ConfigDict(extra="forbid")

-    @field_validator("duration", "gradient_accumulation", mode="before")
+    @field_validator("duration", mode="before")
     def parse_field(cls, value: TimeValueInput) -> TimeValue:
         return parse_number_unit_field(value)

diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index c2467e0..ce1c768 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -30,6 +30,7 @@ from refiners.training_utils.callback import (
 )
 from refiners.training_utils.clock import ClockConfig, TrainingClock
 from refiners.training_utils.common import (
+    Step,
     compute_grad_norm,
     count_learnable_parameters,
     human_readable_number,
@@ -150,7 +151,6 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @register_callback()
     def clock(self, config: ClockConfig) -> TrainingClock:
         return TrainingClock(
-            dataset_length=self.dataset_length,
             batch_size=self.config.training.batch_size,
             training_duration=self.config.training.duration,
             gradient_accumulation=self.config.training.gradient_accumulation,
@@ -279,7 +279,11 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             case _:
                 raise ValueError(f"Unknown scheduler type: {config.type}")

-        warmup_scheduler_steps = self.clock.convert_time_value(config.warmup, config.update_interval.unit)
+        warmup_scheduler_steps = (
+            config.warmup.number
+            if isinstance(config.warmup, Step)
+            else config.warmup.number * self.clock.gradient_accumulation
+        )
         if warmup_scheduler_steps > 0:
             lr_scheduler = WarmupScheduler(
                 optimizer=self.optimizer,
@@ -346,7 +350,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
-        scaled_loss = self.loss / self.clock.num_step_per_iteration
+        scaled_loss = self.loss / self.config.training.gradient_accumulation
         backward(tensors=scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index a147164..617c2b4 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -17,7 +17,7 @@ seed = 0
 device = "cpu"
 dtype = "float32"
 batch_size = 4
-gradient_accumulation = "4:step"
+gradient_accumulation = 4
 gradient_clipping_max_norm = 1.0

 [optimizer]
diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml
index 641bef2..5e2471b 100644
--- a/tests/training_utils/mock_config_2_models.toml
+++ b/tests/training_utils/mock_config_2_models.toml
@@ -12,7 +12,7 @@ verbose = false
 duration = "100:epoch"
 seed = 0
 batch_size = 4
-gradient_accumulation = "4:step"
+gradient_accumulation = 4
 gradient_clipping_max_norm = 1.0

 [optimizer]
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index f929814..ba3ca2c 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -204,10 +204,9 @@ def test_human_readable_number() -> None:
 @pytest.fixture
 def training_clock() -> TrainingClock:
     return TrainingClock(
-        dataset_length=100,
         batch_size=10,
         training_duration=Epoch(5),
-        gradient_accumulation=Epoch(1),
+        gradient_accumulation=1,
         lr_scheduler_interval=Epoch(1),
     )

@@ -215,10 +214,9 @@ def training_clock() -> TrainingClock:
 def test_small_dataset_error():
     with pytest.raises(AssertionError):
         TrainingClock(
-            dataset_length=3,
             batch_size=10,
             training_duration=Epoch(5),
-            gradient_accumulation=Epoch(1),
+            gradient_accumulation=1,
             lr_scheduler_interval=Epoch(1),
         )

@@ -226,35 +224,13 @@ def test_small_dataset_error():
 def test_zero_batch_size_error():
     with pytest.raises(AssertionError):
         TrainingClock(
-            dataset_length=3,
             batch_size=0,
             training_duration=Epoch(5),
-            gradient_accumulation=Epoch(1),
+            gradient_accumulation=1,
             lr_scheduler_interval=Epoch(1),
         )


-def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
-    assert training_clock.convert_time_value_to_steps(Epoch(1)) == 10
-    assert training_clock.convert_time_value_to_steps(Epoch(2)) == 20
-    assert training_clock.convert_time_value_to_steps(Step(1)) == 1
-    assert training_clock.convert_time_value_to_steps(Iteration(1)) == 10
-
-
-def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None:
-    assert training_clock.convert_steps_to_time_unit(10, Epoch) == 1
-    assert training_clock.convert_steps_to_time_unit(20, Epoch) == 2
-    assert training_clock.convert_steps_to_time_unit(1, Step) == 1
-    assert training_clock.convert_steps_to_time_unit(10, Iteration) == 1
-
-
-def test_clock_properties(training_clock: TrainingClock) -> None:
-    assert training_clock.num_batches_per_epoch == 10
-    assert training_clock.num_epochs == 5
-    assert training_clock.num_iterations == 5
-    assert training_clock.num_steps == 50
-
-
 def test_timer_functionality(training_clock: TrainingClock) -> None:
     training_clock.start_timer()
     assert training_clock.start_time is not None
@@ -275,17 +251,12 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
     clock = mock_trainer.clock
     config = mock_trainer.config

-    assert clock.num_step_per_iteration == config.training.gradient_accumulation.number
-    assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size
-
     assert mock_trainer.step_counter == 0
     assert clock.epoch == 0

     mock_trainer.train()

     assert clock.epoch == config.training.duration.number
-    assert clock.step == config.training.duration.number * clock.num_batches_per_epoch
-
     assert mock_trainer.step_counter == mock_trainer.clock.step
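
Below is a minimal usage sketch (not part of the patch) of the clock API after this change, assuming the patched `refiners.training_utils` modules are importable; the counter values are illustrative only.

```python
from refiners.training_utils.clock import TrainingClock
from refiners.training_utils.common import Epoch, Iteration, Step

# gradient_accumulation is now a plain int (was "4:step" / Step(4)),
# and dataset_length is no longer needed to build the clock.
clock = TrainingClock(
    batch_size=4,
    training_duration=Iteration(100),
    gradient_accumulation=4,
    lr_scheduler_interval=Iteration(1),
)

# is_due now dispatches on the TimeValue subtype instead of converting
# everything to steps via dataset_length.
clock.step = 20
clock.iteration = 5
assert clock.is_due(Step(10))       # 20 % 10 == 0
assert clock.is_due(Iteration(5))   # 5 % 5 == 0

# done compares the counter matching the unit of training_duration.
clock.iteration = 100
assert clock.done
```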