fix clock

This commit is contained in:
limiteinductive 2024-04-24 16:42:30 +00:00 committed by Benjamin Trom
parent 44760ac19f
commit 603c8abb1e
3 changed files with 17 additions and 2 deletions

View file

@ -155,7 +155,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.log(f"Iteration {trainer.clock.iteration} started.")
self.log(f"Step {trainer.clock.step} started.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {trainer.clock.step} ended.")
trainer.clock.step += 1
trainer.clock.num_batches_processed += 1

View file

@ -401,10 +401,21 @@ class Trainer(Generic[ConfigType, Batch], ABC):
elif mode == "eval":
item.model.eval()
def _call_callbacks(self, event_name: str) -> None:
for callback in self.callbacks.values():
def _run_event(self, callback: Callback[Any], event_name: str) -> None:
getattr(callback, event_name)(self)
def _call_callbacks(self, event_name: str) -> None:
if event_name.endswith("_begin"):
self._run_event(self.clock, event_name)
for callback in self.callbacks.values():
if callback == self.clock:
continue
self._run_event(callback, event_name)
if event_name.endswith("_end"):
self._run_event(self.clock, event_name)
def _load_callbacks(self) -> None:
for name, config in self.config:
if not isinstance(config, CallbackConfig):

View file

@ -94,6 +94,10 @@ class MockCallback(Callback["MockTrainer"]):
def on_step_end(self, trainer: "MockTrainer") -> None:
if not trainer.clock.is_due(self.config.on_batch_end_interval):
return
# We verify that the callback is always called before the clock is updated
assert trainer.clock.step // 3 <= self.step_end_count
self.step_end_count += 1
with scoped_seed(self.config.on_batch_end_seed):
self.step_end_random_int = random.randint(0, 100)