Fix clock step inconsistencies on batch end

This commit is contained in:
hugojarkoff 2024-04-05 13:44:59 +00:00 committed by hugojarkoff
parent 09af570b23
commit bbb46e3fc7
2 changed files with 2 additions and 4 deletions

View file

@ -179,10 +179,8 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
self.log(f"Iteration {trainer.clock.iteration} started.")
self.log(f"Step {trainer.clock.step} started.")
def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {trainer.clock.step} ended.")
def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {trainer.clock.step} ended.")
trainer.clock.step += 1
trainer.clock.num_batches_processed += 1
trainer.clock.num_minibatches_processed += 1

View file

@ -66,7 +66,7 @@ def scoped_seed(seed: int | Callable[..., int] | None = None) -> Callable[..., C
actual_seed = seed(*args) if callable(seed) else seed
seed_everything(seed=actual_seed)
result = func(*args, **kwargs)
logger.debug(f"Restoring previous seed state")
logger.trace(f"Restoring previous seed state")
random.setstate(random_state)
np.random.set_state(numpy_state)
torch.set_rng_state(torch_state)