mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
TrainerClock: assert dataset_length >= batch_size
This commit is contained in:
parent
0ac290f67d
commit
64692c3b5b
|
@ -29,6 +29,10 @@ class TrainingClock(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
lr_scheduler_interval: TimeValue,
|
lr_scheduler_interval: TimeValue,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
assert batch_size > 0, "Batch size must be greater than 0."
|
||||||
|
assert (
|
||||||
|
dataset_length >= batch_size
|
||||||
|
), f"Dataset length ({dataset_length}) must be greater than batch_size ({batch_size})."
|
||||||
self.dataset_length = dataset_length
|
self.dataset_length = dataset_length
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.training_duration = training_duration
|
self.training_duration = training_duration
|
||||||
|
|
|
@ -137,6 +137,30 @@ def training_clock() -> TrainingClock:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_small_dataset_error():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
TrainingClock(
|
||||||
|
dataset_length=3,
|
||||||
|
batch_size=10,
|
||||||
|
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
|
||||||
|
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
|
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
|
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_batch_size_error():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
TrainingClock(
|
||||||
|
dataset_length=3,
|
||||||
|
batch_size=0,
|
||||||
|
training_duration=TimeValue(number=5, unit=TimeUnit.EPOCH),
|
||||||
|
gradient_accumulation=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
|
evaluation_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
|
lr_scheduler_interval=TimeValue(number=1, unit=TimeUnit.EPOCH),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
|
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None:
|
||||||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
|
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10
|
||||||
assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20
|
assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20
|
||||||
|
|
Loading…
Reference in a new issue