mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
Fix clock log order ; fix that the first iteration was skipped
This commit is contained in:
parent
cc7b62f090
commit
3a7f14e4dc
|
@ -116,16 +116,19 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
|
||||||
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
if self.num_minibatches_processed == 0:
|
if self.num_minibatches_processed == 0:
|
||||||
|
if self.iteration > 0:
|
||||||
|
self.log(f"Iteration {self.iteration - 1} ended.")
|
||||||
self.log(f"Iteration {self.iteration} started.")
|
self.log(f"Iteration {self.iteration} started.")
|
||||||
self.log(f"Step {self.step} started.")
|
self.log(f"Step {self.step} started.")
|
||||||
|
|
||||||
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Step {self.step} ended.")
|
self.log(f"Step {self.step} ended.")
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.num_batches_processed += 1
|
|
||||||
|
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.num_minibatches_processed += 1
|
self.num_minibatches_processed += 1
|
||||||
|
self.num_batches_processed += 1
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Iteration {self.iteration} ended.")
|
|
||||||
self.iteration += 1
|
self.iteration += 1
|
||||||
self.num_minibatches_processed = 0
|
self.num_minibatches_processed = 0
|
||||||
|
|
|
@ -244,11 +244,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
|
||||||
mock_trainer.train()
|
mock_trainer.train()
|
||||||
|
|
||||||
# Check that the callback skips every other iteration
|
# Check that the callback skips every other iteration
|
||||||
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 + 1
|
||||||
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
|
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
|
||||||
|
|
||||||
# Check that the random seed was set
|
# Check that the random seed was set
|
||||||
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
assert mock_trainer.mock_callback.optimizer_step_random_int == 41
|
||||||
assert mock_trainer.mock_callback.step_end_random_int == 81
|
assert mock_trainer.mock_callback.step_end_random_int == 81
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue