diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 9a4c9a5..2ba12fe 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -147,7 +147,6 @@ class OptimizerConfig(BaseModel): class ModelConfig(BaseModel): - checkpoint: Path | None = None # If None, then requires_grad will NOT be changed when loading the model # this can be useful if you want to train only a part of the model requires_grad: bool | None = None @@ -163,7 +162,6 @@ T = TypeVar("T", bound="BaseConfig") class BaseConfig(BaseModel): - models: dict[str, ModelConfig] training: TrainingConfig optimizer: OptimizerConfig scheduler: SchedulerConfig diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index cb3d2b6..c31d4d1 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -430,7 +430,9 @@ class Trainer(Generic[ConfigType, Batch], ABC): registered_callback(config) def _load_models(self) -> None: - for name, config in self.config.models.items(): + for name, config in self.config: + if not isinstance(config, ModelConfig): + continue try: registered_model = getattr(self, name) except AttributeError: diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index 06652b5..814dd0c 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -1,4 +1,4 @@ -[models.mock_model] +[mock_model] requires_grad = true [clock] diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml index e279464..474551f 100644 --- a/tests/training_utils/mock_config_2_models.toml +++ b/tests/training_utils/mock_config_2_models.toml @@ -1,8 +1,8 @@ -[models.mock_model1] +[mock_model1] requires_grad = true learning_rate = 1e-5 -[models.mock_model2] +[mock_model2] requires_grad = true [clock] diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index c1f8b08..ca6f431 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -28,7 +28,7 @@ class MockBatch: class MockConfig(BaseConfig): - pass + mock_model: ModelConfig class MockModel(fl.Chain): @@ -230,9 +230,14 @@ class MockTrainerWith2Models(MockTrainer): return norm(outputs - targets) +class MockConfig_2_Models(BaseConfig): + mock_model1: ModelConfig + mock_model2: ModelConfig + + @pytest.fixture -def mock_config_2_models(test_device: torch.device) -> MockConfig: - return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml") +def mock_config_2_models() -> MockConfig_2_Models: + return MockConfig_2_Models.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml") @pytest.fixture