diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index fb54e7e..4fb111f 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -155,7 +155,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): self.log(f"Iteration {trainer.clock.iteration} started.") self.log(f"Step {trainer.clock.step} started.") - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Step {trainer.clock.step} ended.") trainer.clock.step += 1 trainer.clock.num_batches_processed += 1 diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 2776121..c2467e0 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -401,9 +401,20 @@ class Trainer(Generic[ConfigType, Batch], ABC): elif mode == "eval": item.model.eval() + def _run_event(self, callback: Callback[Any], event_name: str) -> None: + getattr(callback, event_name)(self) + def _call_callbacks(self, event_name: str) -> None: + if event_name.endswith("_begin"): + self._run_event(self.clock, event_name) + for callback in self.callbacks.values(): - getattr(callback, event_name)(self) + if callback == self.clock: + continue + self._run_event(callback, event_name) + + if event_name.endswith("_end"): + self._run_event(self.clock, event_name) def _load_callbacks(self) -> None: for name, config in self.config: diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index f7803d2..a9c91a1 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -94,6 +94,10 @@ class MockCallback(Callback["MockTrainer"]): def on_step_end(self, trainer: "MockTrainer") -> None: if not trainer.clock.is_due(self.config.on_batch_end_interval): return + + # We verify that the callback is always called before the clock is updated + assert trainer.clock.step // 3 <= self.step_end_count + self.step_end_count += 1 with scoped_seed(self.config.on_batch_end_seed): self.step_end_random_int = random.randint(0, 100)