mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix clock
This commit is contained in:
parent
44760ac19f
commit
603c8abb1e
|
@ -155,7 +155,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
self.log(f"Iteration {trainer.clock.iteration} started.")
|
||||
self.log(f"Step {trainer.clock.step} started.")
|
||||
|
||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
self.log(f"Step {trainer.clock.step} ended.")
|
||||
trainer.clock.step += 1
|
||||
trainer.clock.num_batches_processed += 1
|
||||
|
|
|
@ -401,9 +401,20 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
elif mode == "eval":
|
||||
item.model.eval()
|
||||
|
||||
def _run_event(self, callback: Callback[Any], event_name: str) -> None:
|
||||
getattr(callback, event_name)(self)
|
||||
|
||||
def _call_callbacks(self, event_name: str) -> None:
|
||||
if event_name.endswith("_begin"):
|
||||
self._run_event(self.clock, event_name)
|
||||
|
||||
for callback in self.callbacks.values():
|
||||
getattr(callback, event_name)(self)
|
||||
if callback == self.clock:
|
||||
continue
|
||||
self._run_event(callback, event_name)
|
||||
|
||||
if event_name.endswith("_end"):
|
||||
self._run_event(self.clock, event_name)
|
||||
|
||||
def _load_callbacks(self) -> None:
|
||||
for name, config in self.config:
|
||||
|
|
|
@ -94,6 +94,10 @@ class MockCallback(Callback["MockTrainer"]):
|
|||
def on_step_end(self, trainer: "MockTrainer") -> None:
|
||||
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
||||
return
|
||||
|
||||
# We verify that the callback is always called before the clock is updated
|
||||
assert trainer.clock.step // 3 <= self.step_end_count
|
||||
|
||||
self.step_end_count += 1
|
||||
with scoped_seed(self.config.on_batch_end_seed):
|
||||
self.step_end_random_int = random.randint(0, 100)
|
||||
|
|
Loading…
Reference in a new issue