batch to step

This commit is contained in:
limiteinductive 2024-04-24 16:27:43 +00:00 committed by Benjamin Trom
parent b7bb8bba80
commit 061d44888f
4 changed files with 12 additions and 12 deletions

View file

@ -31,9 +31,9 @@ class Callback(Generic[T]):
def on_epoch_end(self, trainer: T) -> None: ...
def on_batch_begin(self, trainer: T) -> None: ...
def on_step_begin(self, trainer: T) -> None: ...
def on_batch_end(self, trainer: T) -> None: ...
def on_step_end(self, trainer: T) -> None: ...
def on_backward_begin(self, trainer: T) -> None: ...

View file

@ -156,7 +156,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
trainer.clock.epoch += 1
trainer.clock.num_batches_processed = 0
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
if self.num_minibatches_processed == 0:
self.log(f"Iteration {trainer.clock.iteration} started.")
self.log(f"Step {trainer.clock.step} started.")

View file

@ -381,9 +381,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
for batch in self.dataloader:
if self.clock.done:
break
self._call_callbacks(event_name="on_batch_begin")
self._call_callbacks(event_name="on_step_begin")
self.step(batch=batch)
self._call_callbacks(event_name="on_batch_end")
self._call_callbacks(event_name="on_step_end")
@staticmethod
def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int:

View file

@ -78,9 +78,9 @@ class MockCallback(Callback["MockTrainer"]):
def __init__(self, config: MockCallbackConfig) -> None:
self.config = config
self.optimizer_step_count = 0
self.batch_end_count = 0
self.step_end_count = 0
self.optimizer_step_random_int: int | None = None
self.batch_end_random_int: int | None = None
self.step_end_random_int: int | None = None
def on_init_begin(self, trainer: "MockTrainer") -> None:
pass
@ -91,12 +91,12 @@ class MockCallback(Callback["MockTrainer"]):
self.optimizer_step_count += 1
self.optimizer_step_random_int = random.randint(0, 100)
def on_batch_end(self, trainer: "MockTrainer") -> None:
def on_step_end(self, trainer: "MockTrainer") -> None:
if not trainer.clock.is_due(self.config.on_batch_end_interval):
return
self.batch_end_count += 1
self.step_end_count += 1
with scoped_seed(self.config.on_batch_end_seed):
self.batch_end_random_int = random.randint(0, 100)
self.step_end_random_int = random.randint(0, 100)
class MockTrainer(Trainer[MockConfig, MockBatch]):
@ -282,11 +282,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
# Check that the callback skips every other iteration
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
# Check that the random seed was set
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
assert mock_trainer.mock_callback.batch_end_random_int == 81
assert mock_trainer.mock_callback.step_end_random_int == 81
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None: