mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)

refactor register_model decorator

This commit is contained in:
parent d6546c9026
commit cef8a9936c
@@ -147,7 +147,6 @@ class OptimizerConfig(BaseModel):


 class ModelConfig(BaseModel):
-    checkpoint: Path | None = None
     # If None, then requires_grad will NOT be changed when loading the model
     # this can be useful if you want to train only a part of the model
     requires_grad: bool | None = None
@@ -163,7 +162,6 @@ T = TypeVar("T", bound="BaseConfig")


 class BaseConfig(BaseModel):
-    models: dict[str, ModelConfig]
     training: TrainingConfig
     optimizer: OptimizerConfig
     scheduler: SchedulerConfig
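Read together, the two hunks above change how models are declared: `BaseConfig` no longer carries a `models: dict[str, ModelConfig]` mapping, and each model instead becomes a typed, top-level field on the concrete config subclass. A minimal sketch of the new shape, assuming pydantic models as in the surrounding code (the `MockConfig`/`mock_model` names are taken from the tests further down; the training/optimizer/scheduler fields are stubbed out):

from pydantic import BaseModel


class ModelConfig(BaseModel):
    # If None, requires_grad will NOT be changed when loading the model.
    requires_grad: bool | None = None


class BaseConfig(BaseModel):
    # training / optimizer / scheduler fields omitted for brevity.
    pass


class MockConfig(BaseConfig):
    # One typed field per model replaces the old `models` dict;
    # the field name doubles as the model's registered name.
    mock_model: ModelConfig


config = MockConfig(mock_model=ModelConfig(requires_grad=True))
assert config.mock_model.requires_grad is True  # was: config.models["mock_model"]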
@@ -430,7 +430,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             registered_callback(config)

     def _load_models(self) -> None:
-        for name, config in self.config.models.items():
+        for name, config in self.config:
+            if not isinstance(config, ModelConfig):
+                continue
             try:
                 registered_model = getattr(self, name)
             except AttributeError:
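The rewritten loop relies on the fact that pydantic `BaseModel` instances are iterable and yield `(field_name, value)` pairs, so `_load_models` can walk every field of the config and skip the ones that are not `ModelConfig`s (training, optimizer, and so on). A self-contained sketch of that iteration pattern (the field names here are illustrative):

from pydantic import BaseModel


class ModelConfig(BaseModel):
    requires_grad: bool | None = None


class TrainingConfig(BaseModel):
    pass


class Config(BaseModel):
    training: TrainingConfig = TrainingConfig()
    mock_model: ModelConfig = ModelConfig(requires_grad=True)


# Iterating a pydantic model yields (field_name, value) tuples, which is
# exactly what `for name, config in self.config:` depends on above.
for name, value in Config():
    if not isinstance(value, ModelConfig):
        continue  # skip non-model fields such as `training`
    print(name, value.requires_grad)  # -> mock_model True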
@@ -1,4 +1,4 @@
-[models.mock_model]
+[mock_model]
 requires_grad = true

 [clock]
@@ -1,8 +1,8 @@
-[models.mock_model1]
+[mock_model1]
 requires_grad = true
 learning_rate = 1e-5

-[models.mock_model2]
+[mock_model2]
 requires_grad = true

 [clock]
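The TOML fixtures mirror the config change: model sections are now top-level tables named after the config fields (`[mock_model1]`) instead of being nested under a `models` table. The diff does not show `load_from_toml` itself, but under the reasonable assumption that it parses the file and validates the resulting dict, the mapping looks like this (a `learning_rate` field on `ModelConfig` is inferred from the fixture above):

import tomllib  # Python 3.11+; the real code may use a different TOML parser

from pydantic import BaseModel


class ModelConfig(BaseModel):
    requires_grad: bool | None = None
    learning_rate: float | None = None


class MockConfig_2_Models(BaseModel):
    mock_model1: ModelConfig
    mock_model2: ModelConfig


toml_text = """
[mock_model1]
requires_grad = true
learning_rate = 1e-5

[mock_model2]
requires_grad = true
"""

# Each top-level table maps directly onto a field of the config class.
config = MockConfig_2_Models(**tomllib.loads(toml_text))
assert config.mock_model1.learning_rate == 1e-5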
@@ -28,7 +28,7 @@ class MockBatch:


 class MockConfig(BaseConfig):
-    pass
+    mock_model: ModelConfig


 class MockModel(fl.Chain):
@@ -230,9 +230,14 @@ class MockTrainerWith2Models(MockTrainer):
         return norm(outputs - targets)


+class MockConfig_2_Models(BaseConfig):
+    mock_model1: ModelConfig
+    mock_model2: ModelConfig
+
+
 @pytest.fixture
-def mock_config_2_models(test_device: torch.device) -> MockConfig:
-    return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
+def mock_config_2_models() -> MockConfig_2_Models:
+    return MockConfig_2_Models.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")


 @pytest.fixture
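Putting the pieces together: after this refactor, every `ModelConfig` field on the config must have a matching attribute on the trainer, because `_load_models` resolves each field name with `getattr(self, name)`. A framework-free sketch of that contract (the trainer class here is illustrative, not the real refiners `Trainer`):

from pydantic import BaseModel


class ModelConfig(BaseModel):
    requires_grad: bool | None = None


class MockConfig(BaseModel):
    mock_model: ModelConfig


class MockTrainer:
    def __init__(self, config: MockConfig) -> None:
        self.config = config
        self.mock_model = "a model object"  # stands in for an fl.Chain

    def load_models(self) -> None:
        for name, config in self.config:
            if not isinstance(config, ModelConfig):
                continue
            try:
                registered_model = getattr(self, name)
            except AttributeError:
                raise ValueError(f"no model registered under the name {name!r}")
            print(f"loading {name}: {registered_model}")


MockTrainer(MockConfig(mock_model=ModelConfig(requires_grad=True))).load_models()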