diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 6b98d6f..eb6c25a 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -109,7 +109,7 @@ def register_model(): name=name, config=config, model=model, learnable_parameters=learnable_parameters ) setattr(self, name, self.models[name].model) - return func(self, config) + return model return wrapper # type: ignore @@ -129,7 +129,7 @@ def register_callback(): callback = func(self, config) self.callbacks[name] = callback setattr(self, name, callback) - return func(self, config) + return callback return wrapper # type: ignore diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index bf72a33..68d5a4b 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -50,6 +50,7 @@ class MockModel(fl.Chain): class MockTrainer(Trainer[MockConfig, MockBatch]): step_counter: int = 0 + model_registration_counter: int = 0 @property def dataset_length(self) -> int: @@ -69,6 +70,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): model = MockModel() if config.use_activation: model.add_activation() + + self.model_registration_counter += 1 return model def compute_loss(self, batch: MockBatch) -> Tensor: @@ -172,6 +175,7 @@ def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: Mock assert isinstance(mock_trainer, MockTrainer) assert mock_trainer.optimizer is not None assert mock_trainer.lr_scheduler is not None + assert mock_trainer.model_registration_counter == 1 def test_training_cycle(mock_trainer: MockTrainer) -> None: