From ab506b4db25dd0a641e569aaa599b247e6f7056d Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Tue, 13 Feb 2024 06:29:55 +0000 Subject: [PATCH] fix bug that was causing double registration --- src/refiners/training_utils/trainer.py | 4 ++-- tests/training_utils/test_trainer.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 6b98d6f..eb6c25a 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -109,7 +109,7 @@ def register_model(): name=name, config=config, model=model, learnable_parameters=learnable_parameters ) setattr(self, name, self.models[name].model) - return func(self, config) + return model return wrapper # type: ignore @@ -129,7 +129,7 @@ def register_callback(): callback = func(self, config) self.callbacks[name] = callback setattr(self, name, callback) - return func(self, config) + return callback return wrapper # type: ignore diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index bf72a33..68d5a4b 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -50,6 +50,7 @@ class MockModel(fl.Chain): class MockTrainer(Trainer[MockConfig, MockBatch]): step_counter: int = 0 + model_registration_counter: int = 0 @property def dataset_length(self) -> int: @@ -69,6 +70,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): model = MockModel() if config.use_activation: model.add_activation() + + self.model_registration_counter += 1 return model def compute_loss(self, batch: MockBatch) -> Tensor: @@ -172,6 +175,7 @@ def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: Mock assert isinstance(mock_trainer, MockTrainer) assert mock_trainer.optimizer is not None assert mock_trainer.lr_scheduler is not None + assert mock_trainer.model_registration_counter == 1 def test_training_cycle(mock_trainer: MockTrainer) -> None: