mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
update test_trainer.py
This commit is contained in:
parent
603c8abb1e
commit
1db0845db2
|
@ -13,6 +13,7 @@ from torch.optim import SGD
|
|||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import norm
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import (
|
||||
Epoch,
|
||||
Iteration,
|
||||
|
@ -57,6 +58,10 @@ class MockCallbackConfig(CallbackConfig):
|
|||
|
||||
|
||||
class MockConfig(BaseConfig):
|
||||
# we register the `early_callback` before the `clock` callback to test the callback ordering
|
||||
early_callback: CallbackConfig = CallbackConfig()
|
||||
clock: ClockConfig = ClockConfig()
|
||||
|
||||
mock_model: MockModelConfig
|
||||
mock_callback: MockCallbackConfig
|
||||
|
||||
|
@ -95,7 +100,7 @@ class MockCallback(Callback["MockTrainer"]):
|
|||
if not trainer.clock.is_due(self.config.on_batch_end_interval):
|
||||
return
|
||||
|
||||
# We verify that the callback is always called before the clock is updated
|
||||
# We verify that this callback is always called before the clock is updated (see `_call_callbacks` in trainer.py)
|
||||
assert trainer.clock.step // 3 <= self.step_end_count
|
||||
|
||||
self.step_end_count += 1
|
||||
|
@ -103,6 +108,15 @@ class MockCallback(Callback["MockTrainer"]):
|
|||
self.step_end_random_int = random.randint(0, 100)
|
||||
|
||||
|
||||
class EarlyMockCallback(Callback["MockTrainer"]):
|
||||
"""
|
||||
A callback that will be registered before the Clock callback to test the callback ordering.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, trainer: "MockTrainer") -> None:
|
||||
assert trainer.clock.start_time is not None, "Clock callback should have been called before this callback."
|
||||
|
||||
|
||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||
step_counter: int = 0
|
||||
model_registration_counter: int = 0
|
||||
|
@ -120,6 +134,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
|||
targets=torch.cat([b.targets for b in batch]),
|
||||
)
|
||||
|
||||
@register_callback()
|
||||
def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
|
||||
return EarlyMockCallback()
|
||||
|
||||
@register_callback()
|
||||
def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
|
||||
return MockCallback(config)
|
||||
|
@ -276,7 +294,7 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
|
|||
|
||||
# Check that the callback skips every other iteration
|
||||
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
|
||||
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
|
||||
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
|
||||
|
||||
# Check that the random seed was set
|
||||
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
|
||||
|
|
Loading…
Reference in a new issue