diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 11df387..740442f 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -116,16 +116,19 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: if self.num_minibatches_processed == 0: + if self.iteration > 0: + self.log(f"Iteration {self.iteration - 1} ended.") self.log(f"Iteration {self.iteration} started.") self.log(f"Step {self.step} started.") def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Step {self.step} ended.") self.step += 1 - self.num_batches_processed += 1 + + def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.num_minibatches_processed += 1 + self.num_batches_processed += 1 def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.log(f"Iteration {self.iteration} ended.") self.iteration += 1 self.num_minibatches_processed = 0 diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index a4ff334..1f3342d 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -244,11 +244,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None: mock_trainer.train() # Check that the callback skips every other iteration - assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 + assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 + 1 assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1 # Check that the random seed was set - assert mock_trainer.mock_callback.optimizer_step_random_int == 93 + assert mock_trainer.mock_callback.optimizer_step_random_int == 41 assert mock_trainer.mock_callback.step_end_random_int == 81