fix bug that was causing double registration

This commit is contained in:
limiteinductive 2024-02-13 06:29:55 +00:00 committed by Benjamin Trom
parent 3488273f50
commit ab506b4db2
2 changed files with 6 additions and 2 deletions

View file

@ -109,7 +109,7 @@ def register_model():
name=name, config=config, model=model, learnable_parameters=learnable_parameters
)
setattr(self, name, self.models[name].model)
return func(self, config)
return model
return wrapper # type: ignore
@ -129,7 +129,7 @@ def register_callback():
callback = func(self, config)
self.callbacks[name] = callback
setattr(self, name, callback)
return func(self, config)
return callback
return wrapper # type: ignore

View file

@ -50,6 +50,7 @@ class MockModel(fl.Chain):
class MockTrainer(Trainer[MockConfig, MockBatch]):
step_counter: int = 0
model_registration_counter: int = 0
@property
def dataset_length(self) -> int:
@ -69,6 +70,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
model = MockModel()
if config.use_activation:
model.add_activation()
self.model_registration_counter += 1
return model
def compute_loss(self, batch: MockBatch) -> Tensor:
@ -172,6 +175,7 @@ def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: Mock
assert isinstance(mock_trainer, MockTrainer)
assert mock_trainer.optimizer is not None
assert mock_trainer.lr_scheduler is not None
assert mock_trainer.model_registration_counter == 1
def test_training_cycle(mock_trainer: MockTrainer) -> None: