mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
TrainerClock: assert dataset_length >= batch_size
This commit is contained in:
parent
0ac290f67d
commit
64692c3b5b
|
@ -29,6 +29,10 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
|||
lr_scheduler_interval: TimeValue,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
assert batch_size > 0, "Batch size must be greater than 0."
|
||||
assert (
|
||||
dataset_length >= batch_size
|
||||
), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})."
|
||||
self.dataset_length = dataset_length
|
||||
self.batch_size = batch_size
|
||||
self.training_duration = training_duration
|
||||
|
|
|
@ -137,6 +137,30 @@ def training_clock() -> TrainingClock:
|
|||
)
|
||||
|
||||
|
||||
def test_small_dataset_error() -> None:
    """TrainingClock must reject a dataset smaller than the batch size.

    The constructor asserts ``dataset_length >= batch_size``; here 3 < 10,
    so instantiation is expected to fail with an AssertionError.
    """
    with pytest.raises(AssertionError):
        TrainingClock(
            dataset_length=3,
            batch_size=10,
            training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
            gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
            evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
            lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
        )
|
||||
|
||||
|
||||
def test_zero_batch_size_error() -> None:
    """TrainingClock must reject a non-positive batch size.

    The constructor asserts ``batch_size > 0``; passing 0 is expected to
    fail with an AssertionError before any other validation runs.
    """
    with pytest.raises(AssertionError):
        TrainingClock(
            dataset_length=3,
            batch_size=0,
            training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
            gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
            evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
            lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
        )
|
||||
|
||||
|
||||
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
    """Epoch-to-step conversion scales linearly for the fixture clock."""
    # The fixture clock converts 1 epoch to 10 steps; 2 epochs to 20.
    steps_per_epoch = 10
    for epochs in (1, 2):
        converted = training_clock.convert_time_unit_to_steps(epochs, TimeUnit.EPOCH)
        assert converted == epochs * steps_per_epoch
|
||||
|
|
Loading…
Reference in a new issue