mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
batch to step
This commit is contained in:
parent
b7bb8bba80
commit
061d44888f
|
@ -31,9 +31,9 @@ class Callback(Generic[T]):
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: T) -> None: ...
|
def on_epoch_end(self, trainer: T) -> None: ...
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: T) -> None: ...
|
def on_step_begin(self, trainer: T) -> None: ...
|
||||||
|
|
||||||
def on_batch_end(self, trainer: T) -> None: ...
|
def on_step_end(self, trainer: T) -> None: ...
|
||||||
|
|
||||||
def on_backward_begin(self, trainer: T) -> None: ...
|
def on_backward_begin(self, trainer: T) -> None: ...
|
||||||
|
|
||||||
|
|
|
@ -156,7 +156,7 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
trainer.clock.epoch += 1
|
trainer.clock.epoch += 1
|
||||||
trainer.clock.num_batches_processed = 0
|
trainer.clock.num_batches_processed = 0
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
if self.num_minibatches_processed == 0:
|
if self.num_minibatches_processed == 0:
|
||||||
self.log(f"Iteration {trainer.clock.iteration} started.")
|
self.log(f"Iteration {trainer.clock.iteration} started.")
|
||||||
self.log(f"Step {trainer.clock.step} started.")
|
self.log(f"Step {trainer.clock.step} started.")
|
||||||
|
|
|
@ -381,9 +381,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
for batch in self.dataloader:
|
for batch in self.dataloader:
|
||||||
if self.clock.done:
|
if self.clock.done:
|
||||||
break
|
break
|
||||||
self._call_callbacks(event_name="on_batch_begin")
|
self._call_callbacks(event_name="on_step_begin")
|
||||||
self.step(batch=batch)
|
self.step(batch=batch)
|
||||||
self._call_callbacks(event_name="on_batch_end")
|
self._call_callbacks(event_name="on_step_end")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int:
|
def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int:
|
||||||
|
|
|
@ -78,9 +78,9 @@ class MockCallback(Callback["MockTrainer"]):
|
||||||
def __init__(self, config: MockCallbackConfig) -> None:
|
def __init__(self, config: MockCallbackConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.optimizer_step_count = 0
|
self.optimizer_step_count = 0
|
||||||
self.batch_end_count = 0
|
self.step_end_count = 0
|
||||||
self.optimizer_step_random_int: int | None = None
|
self.optimizer_step_random_int: int | None = None
|
||||||
self.batch_end_random_int: int | None = None
|
self.step_end_random_int: int | None = None
|
||||||
|
|
||||||
def on_init_begin(self, trainer: "MockTrainer") -> None:
|
def on_init_begin(self, trainer: "MockTrainer") -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -91,12 +91,12 @@ class MockCallback(Callback["MockTrainer"]):
|
||||||
self.optimizer_step_count += 1
|
self.optimizer_step_count += 1
|
||||||
self.optimizer_step_random_int = random.randint(0, 100)
|
self.optimizer_step_random_int = random.randint(0, 100)
|
||||||
|
|
||||||
def on_batch_end(self, trainer: "MockTrainer") -> None:
|
def on_step_end(self, trainer: "MockTrainer") -> None:
|
||||||
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
||||||
return
|
return
|
||||||
self.batch_end_count += 1
|
self.step_end_count += 1
|
||||||
with scoped_seed(self.config.on_batch_end_seed):
|
with scoped_seed(self.config.on_batch_end_seed):
|
||||||
self.batch_end_random_int = random.randint(0, 100)
|
self.step_end_random_int = random.randint(0, 100)
|
||||||
|
|
||||||
|
|
||||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
|
@ -282,11 +282,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
|
||||||
|
|
||||||
# Check that the callback skips every other iteration
|
# Check that the callback skips every other iteration
|
||||||
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
||||||
assert mock_trainer.mock_callback.batch_end_count == mock_trainer.clock.step // 3
|
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
|
||||||
|
|
||||||
# Check that the random seed was set
|
# Check that the random seed was set
|
||||||
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
||||||
assert mock_trainer.mock_callback.batch_end_random_int == 81
|
assert mock_trainer.mock_callback.step_end_random_int == 81
|
||||||
|
|
||||||
|
|
||||||
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
||||||
|
|
Loading…
Reference in a new issue