mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix bug that was causing double registration
This commit is contained in:
parent
3488273f50
commit
ab506b4db2
|
@ -109,7 +109,7 @@ def register_model():
|
|||
name=name, config=config, model=model, learnable_parameters=learnable_parameters
|
||||
)
|
||||
setattr(self, name, self.models[name].model)
|
||||
return func(self, config)
|
||||
return model
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
@ -129,7 +129,7 @@ def register_callback():
|
|||
callback = func(self, config)
|
||||
self.callbacks[name] = callback
|
||||
setattr(self, name, callback)
|
||||
return func(self, config)
|
||||
return callback
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ class MockModel(fl.Chain):
|
|||
|
||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||
step_counter: int = 0
|
||||
model_registration_counter: int = 0
|
||||
|
||||
@property
|
||||
def dataset_length(self) -> int:
|
||||
|
@ -69,6 +70,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
|||
model = MockModel()
|
||||
if config.use_activation:
|
||||
model.add_activation()
|
||||
|
||||
self.model_registration_counter += 1
|
||||
return model
|
||||
|
||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||
|
@ -172,6 +175,7 @@ def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: Mock
|
|||
assert isinstance(mock_trainer, MockTrainer)
|
||||
assert mock_trainer.optimizer is not None
|
||||
assert mock_trainer.lr_scheduler is not None
|
||||
assert mock_trainer.model_registration_counter == 1
|
||||
|
||||
|
||||
def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||
|
|
Loading…
Reference in a new issue