mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
batch to step
This commit is contained in:
parent
b7bb8bba80
commit
061d44888f
|
@ -31,9 +31,9 @@ class Callback(Generic[T]):
|
|||
|
||||
def on_epoch_end(self, trainer: T) -> None: ...
|
||||
|
||||
def on_batch_begin(self, trainer: T) -> None: ...
|
||||
def on_step_begin(self, trainer: T) -> None: ...
|
||||
|
||||
def on_batch_end(self, trainer: T) -> None: ...
|
||||
def on_step_end(self, trainer: T) -> None: ...
|
||||
|
||||
def on_backward_begin(self, trainer: T) -> None: ...
|
||||
|
||||
|
|
|
@ -156,7 +156,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
trainer.clock.epoch += 1
|
||||
trainer.clock.num_batches_processed = 0
|
||||
|
||||
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||
if self.num_minibatches_processed == 0:
|
||||
self.log(f"Iteration {trainer.clock.iteration} started.")
|
||||
self.log(f"Step {trainer.clock.step} started.")
|
||||
|
|
|
@ -381,9 +381,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
for batch in self.dataloader:
|
||||
if self.clock.done:
|
||||
break
|
||||
self._call_callbacks(event_name="on_batch_begin")
|
||||
self._call_callbacks(event_name="on_step_begin")
|
||||
self.step(batch=batch)
|
||||
self._call_callbacks(event_name="on_batch_end")
|
||||
self._call_callbacks(event_name="on_step_end")
|
||||
|
||||
@staticmethod
|
||||
def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int:
|
||||
|
|
|
@ -78,9 +78,9 @@ class MockCallback(Callback["MockTrainer"]):
|
|||
def __init__(self, config: MockCallbackConfig) -> None:
|
||||
self.config = config
|
||||
self.optimizer_step_count = 0
|
||||
self.batch_end_count = 0
|
||||
self.step_end_count = 0
|
||||
self.optimizer_step_random_int: int | None = None
|
||||
self.batch_end_random_int: int | None = None
|
||||
self.step_end_random_int: int | None = None
|
||||
|
||||
def on_init_begin(self, trainer: "MockTrainer") -> None:
|
||||
pass
|
||||
|
@ -91,12 +91,12 @@ class MockCallback(Callback["MockTrainer"]):
|
|||
self.optimizer_step_count += 1
|
||||
self.optimizer_step_random_int = random.randint(0, 100)
|
||||
|
||||
def on_batch_end(self, trainer: "MockTrainer") -> None:
|
||||
def on_step_end(self, trainer: "MockTrainer") -> None:
|
||||
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
||||
return
|
||||
self.batch_end_count += 1
|
||||
self.step_end_count += 1
|
||||
with scoped_seed(self.config.on_batch_end_seed):
|
||||
self.batch_end_random_int = random.randint(0, 100)
|
||||
self.step_end_random_int = random.randint(0, 100)
|
||||
|
||||
|
||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||
|
@ -282,11 +282,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
|
|||
|
||||
# Check that the callback skips every other iteration
|
||||
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
||||
assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3
|
||||
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
|
||||
|
||||
# Check that the random seed was set
|
||||
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
||||
assert mock_trainer.mock_callback.batch_end_random_int == 81
|
||||
assert mock_trainer.mock_callback.step_end_random_int == 81
|
||||
|
||||
|
||||
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
||||
|
|
Loading…
Reference in a new issue