diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index b5471a2..2ce0984 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -31,9 +31,9 @@ class Callback(Generic[T]): def on_epoch_end(self, trainer: T) -> None: ... - def on_batch_begin(self, trainer: T) -> None: ... + def on_step_begin(self, trainer: T) -> None: ... - def on_batch_end(self, trainer: T) -> None: ... + def on_step_end(self, trainer: T) -> None: ... def on_backward_begin(self, trainer: T) -> None: ... diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 559ee1b..d206456 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -156,7 +156,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): trainer.clock.epoch += 1 trainer.clock.num_batches_processed = 0 - def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: + def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: if self.num_minibatches_processed == 0: self.log(f"Iteration {trainer.clock.iteration} started.") self.log(f"Step {trainer.clock.step} started.") diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 10cbf48..997f097 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -381,9 +381,9 @@ class Trainer(Generic[ConfigType, Batch], ABC): for batch in self.dataloader: if self.clock.done: break - self._call_callbacks(event_name="on_batch_begin") + self._call_callbacks(event_name="on_step_begin") self.step(batch=batch) - self._call_callbacks(event_name="on_batch_end") + self._call_callbacks(event_name="on_step_end") @staticmethod def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int: diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index fd62312..fdfef5b 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -78,9 +78,9 @@ class MockCallback(Callback["MockTrainer"]): def __init__(self, config: MockCallbackConfig) -> None: self.config = config self.optimizer_step_count = 0 - self.batch_end_count = 0 + self.step_end_count = 0 self.optimizer_step_random_int: int | None = None - self.batch_end_random_int: int | None = None + self.step_end_random_int: int | None = None def on_init_begin(self, trainer: "MockTrainer") -> None: pass @@ -91,12 +91,12 @@ class MockCallback(Callback["MockTrainer"]): self.optimizer_step_count += 1 self.optimizer_step_random_int = random.randint(0, 100) - def on_batch_end(self, trainer: "MockTrainer") -> None: + def on_step_end(self, trainer: "MockTrainer") -> None: if not trainer.clock.is_due(self.config.on_batch_end_interval): return - self.batch_end_count += 1 + self.step_end_count += 1 with scoped_seed(self.config.on_batch_end_seed): - self.batch_end_random_int = random.randint(0, 100) + self.step_end_random_int = random.randint(0, 100) class MockTrainer(Trainer[MockConfig, MockBatch]): @@ -282,11 +282,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None: # Check that the callback skips every other iteration assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 - assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3 + assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 # Check that the random seed was set assert mock_trainer.mock_callback.optimizer_step_random_int == 93 - assert mock_trainer.mock_callback.batch_end_random_int == 81 + assert mock_trainer.mock_callback.step_end_random_int == 81 def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None: