refactor register_model decorator

This commit is contained in:
limiteinductive 2024-02-12 13:17:51 +00:00 committed by Benjamin Trom
parent d6546c9026
commit cef8a9936c
5 changed files with 14 additions and 9 deletions

View file

@ -147,7 +147,6 @@ class OptimizerConfig(BaseModel):
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
checkpoint: Path | None = None
# If None, then requires_grad will NOT be changed when loading the model # If None, then requires_grad will NOT be changed when loading the model
# this can be useful if you want to train only a part of the model # this can be useful if you want to train only a part of the model
requires_grad: bool | None = None requires_grad: bool | None = None
@ -163,7 +162,6 @@ T = TypeVar("T", bound="BaseConfig")
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
models: dict[str, ModelConfig]
training: TrainingConfig training: TrainingConfig
optimizer: OptimizerConfig optimizer: OptimizerConfig
scheduler: SchedulerConfig scheduler: SchedulerConfig

View file

@ -430,7 +430,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
registered_callback(config) registered_callback(config)
def _load_models(self) -> None: def _load_models(self) -> None:
for name, config in self.config.models.items(): for name, config in self.config:
if not isinstance(config, ModelConfig):
continue
try: try:
registered_model = getattr(self, name) registered_model = getattr(self, name)
except AttributeError: except AttributeError:

View file

@ -1,4 +1,4 @@
[models.mock_model] [mock_model]
requires_grad = true requires_grad = true
[clock] [clock]

View file

@ -1,8 +1,8 @@
[models.mock_model1] [mock_model1]
requires_grad = true requires_grad = true
learning_rate = 1e-5 learning_rate = 1e-5
[models.mock_model2] [mock_model2]
requires_grad = true requires_grad = true
[clock] [clock]

View file

@ -28,7 +28,7 @@ class MockBatch:
class MockConfig(BaseConfig): class MockConfig(BaseConfig):
pass mock_model: ModelConfig
class MockModel(fl.Chain): class MockModel(fl.Chain):
@ -230,9 +230,14 @@ class MockTrainerWith2Models(MockTrainer):
return norm(outputs - targets) return norm(outputs - targets)
class MockConfig_2_Models(BaseConfig):
mock_model1: ModelConfig
mock_model2: ModelConfig
@pytest.fixture @pytest.fixture
def mock_config_2_models(test_device: torch.device) -> MockConfig: def mock_config_2_models() -> MockConfig_2_Models:
return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml") return MockConfig_2_Models.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
@pytest.fixture @pytest.fixture