mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
fix clock
This commit is contained in:
parent
44760ac19f
commit
603c8abb1e
|
@ -155,7 +155,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
self.log(f"Iteration {trainer.clock.iteration} started.")
|
self.log(f"Iteration {trainer.clock.iteration} started.")
|
||||||
self.log(f"Step {trainer.clock.step} started.")
|
self.log(f"Step {trainer.clock.step} started.")
|
||||||
|
|
||||||
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
self.log(f"Step {trainer.clock.step} ended.")
|
self.log(f"Step {trainer.clock.step} ended.")
|
||||||
trainer.clock.step += 1
|
trainer.clock.step += 1
|
||||||
trainer.clock.num_batches_processed += 1
|
trainer.clock.num_batches_processed += 1
|
||||||
|
|
|
@ -401,9 +401,20 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
elif mode == "eval":
|
elif mode == "eval":
|
||||||
item.model.eval()
|
item.model.eval()
|
||||||
|
|
||||||
|
def _run_event(self, callback: Callback[Any], event_name: str) -> None:
|
||||||
|
getattr(callback, event_name)(self)
|
||||||
|
|
||||||
def _call_callbacks(self, event_name: str) -> None:
|
def _call_callbacks(self, event_name: str) -> None:
|
||||||
|
if event_name.endswith("_begin"):
|
||||||
|
self._run_event(self.clock, event_name)
|
||||||
|
|
||||||
for callback in self.callbacks.values():
|
for callback in self.callbacks.values():
|
||||||
getattr(callback, event_name)(self)
|
if callback == self.clock:
|
||||||
|
continue
|
||||||
|
self._run_event(callback, event_name)
|
||||||
|
|
||||||
|
if event_name.endswith("_end"):
|
||||||
|
self._run_event(self.clock, event_name)
|
||||||
|
|
||||||
def _load_callbacks(self) -> None:
|
def _load_callbacks(self) -> None:
|
||||||
for name, config in self.config:
|
for name, config in self.config:
|
||||||
|
|
|
@ -94,6 +94,10 @@ class MockCallback(Callback["MockTrainer"]):
|
||||||
def on_step_end(self, trainer: "MockTrainer") -> None:
|
def on_step_end(self, trainer: "MockTrainer") -> None:
|
||||||
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# We verify that the callback is always called before the clock is updated
|
||||||
|
assert trainer.clock.step // 3 <= self.step_end_count
|
||||||
|
|
||||||
self.step_end_count += 1
|
self.step_end_count += 1
|
||||||
with scoped_seed(self.config.on_batch_end_seed):
|
with scoped_seed(self.config.on_batch_end_seed):
|
||||||
self.step_end_random_int = random.randint(0, 100)
|
self.step_end_random_int = random.randint(0, 100)
|
||||||
|
|
Loading…
Reference in a new issue