less than 1 epoch training duration

This commit is contained in:
Colle 2024-02-08 19:20:31 +01:00 committed by GitHub
parent 41508e0865
commit f4aa0271b8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 0 deletions

View file

@ -522,6 +522,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
def epoch(self) -> None:
"""Perform a single epoch."""
for batch in self.dataloader:
if self.clock.done:
break
self._call_callbacks(event_name="on_batch_begin")
self.step(batch=batch)
self._call_callbacks(event_name="on_batch_end")

View file

@ -80,6 +80,13 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
return MockTrainer(config=mock_config)
@pytest.fixture
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
mock_config_short = mock_config.copy()
mock_config_short.training.duration = {"number": 3, "unit": TimeUnit.STEP}
return MockTrainer(config=mock_config_short)
@pytest.fixture
def mock_model() -> fl.Chain:
return MockModel()
@ -176,6 +183,18 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
assert mock_trainer.step_counter == mock_trainer.clock.step
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
clock = mock_trainer_short.clock
config = mock_trainer_short.config
assert mock_trainer_short.step_counter == 0
assert mock_trainer_short.clock.epoch == 0
mock_trainer_short.train()
assert clock.step == config.training.duration["number"]
@pytest.fixture
def warmup_scheduler():
optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)