fix clock

This commit is contained in:
limiteinductive 2024-04-24 16:42:30 +00:00 committed by Benjamin Trom
parent 44760ac19f
commit 603c8abb1e
3 changed files with 17 additions and 2 deletions

View file

@ -155,7 +155,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.log(f"Iteration {trainer.clock.iteration} started.") self.log(f"Iteration {trainer.clock.iteration} started.")
self.log(f"Step {trainer.clock.step} started.") self.log(f"Step {trainer.clock.step} started.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {trainer.clock.step} ended.") self.log(f"Step {trainer.clock.step} ended.")
trainer.clock.step += 1 trainer.clock.step += 1
trainer.clock.num_batches_processed += 1 trainer.clock.num_batches_processed += 1

View file

@ -401,9 +401,20 @@ class Trainer(Generic[ConfigType, Batch], ABC):
elif mode == "eval": elif mode == "eval":
item.model.eval() item.model.eval()
def _run_event(self, callback: Callback[Any], event_name: str) -> None:
getattr(callback, event_name)(self)
def _call_callbacks(self, event_name: str) -> None: def _call_callbacks(self, event_name: str) -> None:
if event_name.endswith("_begin"):
self._run_event(self.clock, event_name)
for callback in self.callbacks.values(): for callback in self.callbacks.values():
getattr(callback, event_name)(self) if callback == self.clock:
continue
self._run_event(callback, event_name)
if event_name.endswith("_end"):
self._run_event(self.clock, event_name)
def _load_callbacks(self) -> None: def _load_callbacks(self) -> None:
for name, config in self.config: for name, config in self.config:

View file

@ -94,6 +94,10 @@ class MockCallback(Callback["MockTrainer"]):
def on_step_end(self, trainer: "MockTrainer") -> None: def on_step_end(self, trainer: "MockTrainer") -> None:
if not trainer.clock.is_due(self.config.on_batch_end_interval): if not trainer.clock.is_due(self.config.on_batch_end_interval):
return return
# We verify that the callback is always called before the clock is updated
assert trainer.clock.step // 3 <= self.step_end_count
self.step_end_count += 1 self.step_end_count += 1
with scoped_seed(self.config.on_batch_end_seed): with scoped_seed(self.config.on_batch_end_seed):
self.step_end_random_int = random.randint(0, 100) self.step_end_random_int = random.randint(0, 100)