Fix clock log order; fix the first iteration being skipped

This commit is contained in:
limiteinductive 2024-05-21 09:36:44 +00:00 committed by Benjamin Trom
parent cc7b62f090
commit 3a7f14e4dc
2 changed files with 7 additions and 4 deletions

View file

@ -116,16 +116,19 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
# Step-begin hook: logs "Step {step} started."; the iteration messages are
# guarded by the num_minibatches_processed == 0 check (nesting presumed --
# indentation is not preserved in this view).
# NOTE(review): this listing is a diff rendered without +/- markers, so the
# "Iteration {iteration - 1} ended." branch may be the removed side of the
# change (an "ended" log also appears in on_optimizer_step_end) -- confirm
# against the repository.
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
if self.num_minibatches_processed == 0:
if self.iteration > 0:
self.log(f"Iteration {self.iteration - 1} ended.")
self.log(f"Iteration {self.iteration} started.")
self.log(f"Step {self.step} started.")
# Step-end hook: logs "Step {step} ended." and then advances the step counter.
# NOTE(review): the num_batches_processed increment also appears in
# on_backward_end; since +/- markers are missing from this diff view, one of
# the two occurrences may be a removed line -- confirm against the repository.
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {self.step} ended.")
self.step += 1
self.num_batches_processed += 1
# Backward-end hook: increments the minibatch counter (the counter that
# on_step_begin checks against 0 and on_optimizer_step_end resets).
# NOTE(review): the num_batches_processed increment is duplicated in
# on_step_end in this marker-less diff view; only one of them may exist in
# the committed code -- confirm against the repository.
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.num_minibatches_processed += 1
self.num_batches_processed += 1
# Optimizer-step-end hook: logs the end of the current iteration *before*
# incrementing the iteration counter (so the message carries the number of the
# iteration that just finished), then resets the minibatch counter to 0.
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Iteration {self.iteration} ended.")
self.iteration += 1
self.num_minibatches_processed = 0

View file

@ -244,11 +244,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
mock_trainer.train()
# Check that the callback skips every other iteration
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 + 1
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
# Check that the random seed was set
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
assert mock_trainer.mock_callback.optimizer_step_random_int == 41
assert mock_trainer.mock_callback.step_end_random_int == 81