mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
less than 1 epoch training duration
This commit is contained in:
parent
41508e0865
commit
f4aa0271b8
|
@ -522,6 +522,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
def epoch(self) -> None:
|
||||
"""Perform a single epoch."""
|
||||
for batch in self.dataloader:
|
||||
if self.clock.done:
|
||||
break
|
||||
self._call_callbacks(event_name="on_batch_begin")
|
||||
self.step(batch=batch)
|
||||
self._call_callbacks(event_name="on_batch_end")
|
||||
|
|
|
@ -80,6 +80,13 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
|
|||
return MockTrainer(config=mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
|
||||
mock_config_short = mock_config.copy()
|
||||
mock_config_short.training.duration = {"number": 3, "unit": TimeUnit.STEP}
|
||||
return MockTrainer(config=mock_config_short)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model() -> fl.Chain:
|
||||
return MockModel()
|
||||
|
@ -176,6 +183,18 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
|||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
||||
|
||||
|
||||
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
||||
clock = mock_trainer_short.clock
|
||||
config = mock_trainer_short.config
|
||||
|
||||
assert mock_trainer_short.step_counter == 0
|
||||
assert mock_trainer_short.clock.epoch == 0
|
||||
|
||||
mock_trainer_short.train()
|
||||
|
||||
assert clock.step == config.training.duration["number"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def warmup_scheduler():
|
||||
optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
|
||||
|
|
Loading…
Reference in a new issue