mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
less than 1 epoch training duration
This commit is contained in:
parent
41508e0865
commit
f4aa0271b8
|
@ -522,6 +522,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
def epoch(self) -> None:
|
def epoch(self) -> None:
|
||||||
"""Perform a single epoch."""
|
"""Perform a single epoch."""
|
||||||
for batch in self.dataloader:
|
for batch in self.dataloader:
|
||||||
|
if self.clock.done:
|
||||||
|
break
|
||||||
self._call_callbacks(event_name="on_batch_begin")
|
self._call_callbacks(event_name="on_batch_begin")
|
||||||
self.step(batch=batch)
|
self.step(batch=batch)
|
||||||
self._call_callbacks(event_name="on_batch_end")
|
self._call_callbacks(event_name="on_batch_end")
|
||||||
|
|
|
@ -80,6 +80,13 @@ def mock_trainer(mock_config: MockConfig) -> MockTrainer:
|
||||||
return MockTrainer(config=mock_config)
|
return MockTrainer(config=mock_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_trainer_short(mock_config: MockConfig) -> MockTrainer:
|
||||||
|
mock_config_short = mock_config.copy()
|
||||||
|
mock_config_short.training.duration = {"number": 3, "unit": TimeUnit.STEP}
|
||||||
|
return MockTrainer(config=mock_config_short)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_model() -> fl.Chain:
|
def mock_model() -> fl.Chain:
|
||||||
return MockModel()
|
return MockModel()
|
||||||
|
@ -176,6 +183,18 @@ def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
assert mock_trainer.step_counter == mock_trainer.clock.step
|
assert mock_trainer.step_counter == mock_trainer.clock.step
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_short_cycle(mock_trainer_short: MockTrainer) -> None:
|
||||||
|
clock = mock_trainer_short.clock
|
||||||
|
config = mock_trainer_short.config
|
||||||
|
|
||||||
|
assert mock_trainer_short.step_counter == 0
|
||||||
|
assert mock_trainer_short.clock.epoch == 0
|
||||||
|
|
||||||
|
mock_trainer_short.train()
|
||||||
|
|
||||||
|
assert clock.step == config.training.duration["number"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def warmup_scheduler():
|
def warmup_scheduler():
|
||||||
optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
|
optimizer = SGD([nn.Parameter(torch.randn(2, 2), requires_grad=True)], lr=0.1)
|
||||||
|
|
Loading…
Reference in a new issue