TrainerClock: assert dataset_length >= batch_size

This commit is contained in:
Pierre Colle 2024-04-11 17:43:07 +00:00 committed by Colle
parent 0ac290f67d
commit 64692c3b5b
2 changed files with 28 additions and 0 deletions

View file

@ -29,6 +29,10 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
lr_scheduler_interval: TimeValue, lr_scheduler_interval: TimeValue,
verbose: bool = True, verbose: bool = True,
) -> None: ) -> None:
assert batch_size > 0, "Batch size must be greater than 0."
assert (
dataset_length >= batch_size
), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})."
self.dataset_length = dataset_length self.dataset_length = dataset_length
self.batch_size = batch_size self.batch_size = batch_size
self.training_duration = training_duration self.training_duration = training_duration

View file

@ -137,6 +137,30 @@ def training_clock() -> TrainingClock:
) )
def test_small_dataset_error():
with pytest.raises(AssertionError):
TrainingClock(
dataset_length=3,
batch_size=10,
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
)
def test_zero_batch_size_error():
with pytest.raises(AssertionError):
TrainingClock(
dataset_length=3,
batch_size=0,
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
)
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None: def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10 assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20 assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20