From 64692c3b5beeb808ef1609f901f3a124b9cae0ed Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Thu, 11 Apr 2024 17:43:07 +0000 Subject: [PATCH] TrainerClock: assert dataset_length >= batch_size --- src/refiners/training_utils/clock.py | 4 ++++ tests/training_utils/test_trainer.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 0d69cd5..07ab952 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -29,6 +29,10 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]): lr_scheduler_interval: TimeValue, verbose: bool = True, ) -> None: + assert batch_size > 0, "Batch size must be greater than 0." + assert ( + dataset_length >= batch_size + ), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})." self.dataset_length = dataset_length self.batch_size = batch_size self.training_duration = training_duration diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 19fa462..940445d 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -137,6 +137,30 @@ def training_clock() -> TrainingClock: ) +def test_small_dataset_error(): + with pytest.raises(AssertionError): + TrainingClock( + dataset_length=3, + batch_size=10, + training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH), + gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH), + evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH), + lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH), + ) + + +def test_zero_batch_size_error(): + with pytest.raises(AssertionError): + TrainingClock( + dataset_length=3, + batch_size=0, + training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH), + gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH), + evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH), + lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH), + ) + + def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None: assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10 assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20