update test_trainer.py

This commit is contained in:
Benjamin Trom 2024-04-25 15:07:40 +02:00
parent 603c8abb1e
commit 1db0845db2

View file

@ -13,6 +13,7 @@ from torch.optim import SGD
from refiners.fluxion import layers as fl
from refiners.fluxion.utils import norm
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import (
Epoch,
Iteration,
@ -57,6 +58,10 @@ class MockCallbackConfig(CallbackConfig):
class MockConfig(BaseConfig):
# we register the `early_callback` before the `clock` callback to test the callback ordering
early_callback: CallbackConfig = CallbackConfig()
clock: ClockConfig = ClockConfig()
mock_model: MockModelConfig
mock_callback: MockCallbackConfig
@ -95,7 +100,7 @@ class MockCallback(Callback["MockTrainer"]):
if not trainer.clock.is_due(self.config.on_batch_end_interval):
return
# We verify that the callback is always called before the clock is updated
# We verify that this callback is always called before the clock is updated (see `_call_callbacks` in trainer.py)
assert trainer.clock.step // 3 <= self.step_end_count
self.step_end_count += 1
@ -103,6 +108,15 @@ class MockCallback(Callback["MockTrainer"]):
self.step_end_random_int = random.randint(0, 100)
class EarlyMockCallback(Callback["MockTrainer"]):
"""
A callback that will be registered before the Clock callback to test the callback ordering.
"""
def on_train_begin(self, trainer: "MockTrainer") -> None:
assert trainer.clock.start_time is not None, "Clock callback should have been called before this callback."
class MockTrainer(Trainer[MockConfig, MockBatch]):
step_counter: int = 0
model_registration_counter: int = 0
@ -120,6 +134,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
targets=torch.cat([b.targets for b in batch]),
)
@register_callback()
def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
return EarlyMockCallback()
@register_callback()
def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
return MockCallback(config)
@ -276,7 +294,7 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
# Check that the callback skips every other iteration
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
# Check that the random seed was set
assert mock_trainer.mock_callback.optimizer_step_random_int == 93