mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix bug that was causing double registration
This commit is contained in:
parent
3488273f50
commit
ab506b4db2
|
@ -109,7 +109,7 @@ def register_model():
|
||||||
name=name, config=config, model=model, learnable_parameters=learnable_parameters
|
name=name, config=config, model=model, learnable_parameters=learnable_parameters
|
||||||
)
|
)
|
||||||
setattr(self, name, self.models[name].model)
|
setattr(self, name, self.models[name].model)
|
||||||
return func(self, config)
|
return model
|
||||||
|
|
||||||
return wrapper # type: ignore
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ def register_callback():
|
||||||
callback = func(self, config)
|
callback = func(self, config)
|
||||||
self.callbacks[name] = callback
|
self.callbacks[name] = callback
|
||||||
setattr(self, name, callback)
|
setattr(self, name, callback)
|
||||||
return func(self, config)
|
return callback
|
||||||
|
|
||||||
return wrapper # type: ignore
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -50,6 +50,7 @@ class MockModel(fl.Chain):
|
||||||
|
|
||||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
step_counter: int = 0
|
step_counter: int = 0
|
||||||
|
model_registration_counter: int = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset_length(self) -> int:
|
def dataset_length(self) -> int:
|
||||||
|
@ -69,6 +70,8 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
model = MockModel()
|
model = MockModel()
|
||||||
if config.use_activation:
|
if config.use_activation:
|
||||||
model.add_activation()
|
model.add_activation()
|
||||||
|
|
||||||
|
self.model_registration_counter += 1
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||||
|
@ -172,6 +175,7 @@ def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: Mock
|
||||||
assert isinstance(mock_trainer, MockTrainer)
|
assert isinstance(mock_trainer, MockTrainer)
|
||||||
assert mock_trainer.optimizer is not None
|
assert mock_trainer.optimizer is not None
|
||||||
assert mock_trainer.lr_scheduler is not None
|
assert mock_trainer.lr_scheduler is not None
|
||||||
|
assert mock_trainer.model_registration_counter == 1
|
||||||
|
|
||||||
|
|
||||||
def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
def test_training_cycle(mock_trainer: MockTrainer) -> None:
|
||||||
|
|
Loading…
Reference in a new issue