diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 2ba12fe..8547e1e 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -150,6 +150,7 @@ class ModelConfig(BaseModel):
     # If None, then requires_grad will NOT be changed when loading the model
     # this can be useful if you want to train only a part of the model
     requires_grad: bool | None = None
+    # Optional, per-model optimizer parameters
     learning_rate: float | None = None
     betas: tuple[float, float] | None = None
     eps: float | None = None
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index c31d4d1..6b98d6f 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -92,12 +92,13 @@ class ModelItem:
 
 ModelRegistry = dict[str, ModelItem]
 ModuleT = TypeVar("ModuleT", bound=fl.Module)
+ModelConfigT = TypeVar("ModelConfigT", bound=ModelConfig)
 
 
 def register_model():
-    def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
+    def decorator(func: Callable[[Any, ModelConfigT], ModuleT]) -> ModuleT:
         @wraps(func)
-        def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
+        def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
             name = func.__name__
             model = func(self, config)
             model = model.to(self.device, dtype=self.dtype)
@@ -117,12 +118,13 @@ def register_model():
 
 CallbackRegistry = dict[str, Callback[Any]]
 CallbackT = TypeVar("CallbackT", bound=Callback[Any])
+CallbackConfigT = TypeVar("CallbackConfigT", bound=CallbackConfig)
 
 
 def register_callback():
-    def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
+    def decorator(func: Callable[[Any, CallbackConfigT], CallbackT]) -> CallbackT:
         @wraps(func)
-        def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
+        def wrapper(self: "Trainer[BaseConfig, Any]", config: CallbackConfigT) -> CallbackT:
             name = func.__name__
             callback = func(self, config)
             self.callbacks[name] = callback
diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml
index 814dd0c..a5017a9 100644
--- a/tests/training_utils/mock_config.toml
+++ b/tests/training_utils/mock_config.toml
@@ -1,5 +1,6 @@
 [mock_model]
 requires_grad = true
+use_activation = true
 
 [clock]
 verbose = false
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index ca6f431..bf72a33 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -27,8 +27,12 @@ class MockBatch:
     targets: torch.Tensor
 
 
+class MockModelConfig(ModelConfig):
+    use_activation: bool
+
+
 class MockConfig(BaseConfig):
-    mock_model: ModelConfig
+    mock_model: MockModelConfig
 
 
 class MockModel(fl.Chain):
@@ -39,6 +43,10 @@ class MockModel(fl.Chain):
             fl.Linear(10, 10),
         )
 
+    def add_activation(self) -> None:
+        self.insert(1, fl.SiLU())
+        self.insert(3, fl.SiLU())
+
 
 class MockTrainer(Trainer[MockConfig, MockBatch]):
     step_counter: int = 0
@@ -57,8 +65,11 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
         )
 
     @register_model()
-    def mock_model(self, config: ModelConfig) -> MockModel:
-        return MockModel()
+    def mock_model(self, config: MockModelConfig) -> MockModel:
+        model = MockModel()
+        if config.use_activation:
+            model.add_activation()
+        return model
 
     def compute_loss(self, batch: MockBatch) -> Tensor:
         self.step_counter += 1
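
For context, the practical effect of the new ModelConfigT / CallbackConfigT bounds is that a trainer can declare a model-specific config subclass and @register_model() / @register_callback() will carry that narrower type through to the registered method, as the updated test does with MockModelConfig.use_activation. The snippet below is a minimal sketch of that pattern from a downstream user's perspective, not code from this PR: the MLPConfig / hidden_dim names are illustrative, and the import paths are assumed from the files touched above.

from typing import Any

import refiners.fluxion.layers as fl
from refiners.training_utils.config import BaseConfig, ModelConfig
from refiners.training_utils.trainer import Trainer, register_model


class MLPConfig(ModelConfig):
    # hypothetical per-model field; the per-model optimizer fields
    # (learning_rate, betas, eps, ...) are inherited from ModelConfig
    hidden_dim: int = 32


class MyConfig(BaseConfig):
    mlp: MLPConfig


class MyTrainer(Trainer[MyConfig, Any]):  # other required Trainer methods omitted
    @register_model()
    def mlp(self, config: MLPConfig) -> fl.Chain:
        # `config` is typed as MLPConfig rather than the base ModelConfig,
        # so model-specific fields like `hidden_dim` type-check.
        return fl.Chain(
            fl.Linear(10, config.hidden_dim),
            fl.SiLU(),
            fl.Linear(config.hidden_dim, 10),
        )

The corresponding TOML table would then carry the extra field (e.g. hidden_dim = 32 under [mlp]), in the same way use_activation = true was added under [mock_model] above.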