From 1db0845db204726700f81a391c3400e046780cb1 Mon Sep 17 00:00:00 2001
From: Benjamin Trom
Date: Thu, 25 Apr 2024 15:07:40 +0200
Subject: [PATCH] update test_trainer.py

---
 tests/training_utils/test_trainer.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index a9c91a1..f929814 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -13,6 +13,7 @@ from torch.optim import SGD
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
 from refiners.training_utils.callback import Callback, CallbackConfig
+from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import (
     Epoch,
     Iteration,
@@ -57,6 +58,10 @@ class MockCallbackConfig(CallbackConfig):
 
 
 class MockConfig(BaseConfig):
+    # we register the `early_callback` before the `clock` callback to test the callback ordering
+    early_callback: CallbackConfig = CallbackConfig()
+    clock: ClockConfig = ClockConfig()
+
     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
 
@@ -95,7 +100,7 @@ class MockCallback(Callback["MockTrainer"]):
         if not trainer.clock.is_due(self.config.on_batch_end_interval):
             return
 
-        # We verify that the callback is always called before the clock is updated
+        # We verify that this callback is always called before the clock is updated (see `_call_callbacks` in trainer.py)
         assert trainer.clock.step // 3 <= self.step_end_count
 
         self.step_end_count += 1
@@ -103,6 +108,15 @@ class MockCallback(Callback["MockTrainer"]):
         self.step_end_random_int = random.randint(0, 100)
 
 
+class EarlyMockCallback(Callback["MockTrainer"]):
+    """
+    A callback that will be registered before the Clock callback to test the callback ordering.
+    """
+
+    def on_train_begin(self, trainer: "MockTrainer") -> None:
+        assert trainer.clock.start_time is not None, "Clock callback should have been called before this callback."
+
+
 class MockTrainer(Trainer[MockConfig, MockBatch]):
     step_counter: int = 0
     model_registration_counter: int = 0
@@ -120,6 +134,10 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
             targets=torch.cat([b.targets for b in batch]),
         )
 
+    @register_callback()
+    def early_callback(self, config: CallbackConfig) -> EarlyMockCallback:
+        return EarlyMockCallback()
+
     @register_callback()
     def mock_callback(self, config: MockCallbackConfig) -> MockCallback:
         return MockCallback(config)
@@ -276,7 +294,7 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
     # Check that the callback skips every other iteration
    assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
 
-    assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3
+    assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1
 
     # Check that the random seed was set
     assert mock_trainer.mock_callback.optimizer_step_random_int == 93
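
Note (not part of the patch): below is a minimal, self-contained sketch of the ordering property the new EarlyMockCallback appears to exercise: even when another callback is registered before the clock in the config, the clock should be dispatched first, so `clock.start_time` is already set when other callbacks' `on_train_begin` runs. The `MiniTrainer`, `MiniClock`, and `EarlyCallback` names are hypothetical stand-ins, and the reordering shown in `_call_callbacks` is an assumption about the intent, not refiners' actual implementation.

# Illustrative sketch only; assumes the trainer always dispatches the clock
# callback first, regardless of registration order.
import time


class MiniCallback:
    def on_train_begin(self, trainer: "MiniTrainer") -> None: ...


class MiniClock(MiniCallback):
    start_time: float | None = None

    def on_train_begin(self, trainer: "MiniTrainer") -> None:
        self.start_time = time.time()


class EarlyCallback(MiniCallback):
    def on_train_begin(self, trainer: "MiniTrainer") -> None:
        # Passes only if the clock ran first, mirroring EarlyMockCallback above.
        assert trainer.clock.start_time is not None


class MiniTrainer:
    def __init__(self) -> None:
        self.clock = MiniClock()
        # Registered "too early" on purpose, like `early_callback` in the patch.
        self.callbacks = [EarlyCallback(), self.clock]

    def _call_callbacks(self, event: str) -> None:
        # Assumed behavior: dispatch the clock first, whatever the registration
        # order, so other callbacks can read an up-to-date clock.
        ordered = [self.clock] + [cb for cb in self.callbacks if cb is not self.clock]
        for cb in ordered:
            getattr(cb, event)(self)


MiniTrainer()._call_callbacks("on_train_begin")  # no AssertionError raised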