Enforce correct subtype for the config param in both decorators

Also add a custom ModelConfig for the MockTrainer test

Update src/refiners/training_utils/config.py

Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
limiteinductive 2024-02-12 14:53:24 +00:00 committed by Benjamin Trom
parent 0caa72a082
commit 3488273f50
4 changed files with 22 additions and 7 deletions

View file

@ -150,6 +150,7 @@ class ModelConfig(BaseModel):
# If None, then requires_grad will NOT be changed when loading the model
# this can be useful if you want to train only a part of the model
requires_grad: bool | None = None
# Optional, per-model optimizer parameters
learning_rate: float | None = None
betas: tuple[float, float] | None = None
eps: float | None = None

View file

@ -92,12 +92,13 @@ class ModelItem:
ModelRegistry = dict[str, ModelItem]
ModuleT = TypeVar("ModuleT", bound=fl.Module)
ModelConfigT = TypeVar("ModelConfigT", bound=ModelConfig)
def register_model():
def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
def decorator(func: Callable[[Any, ModelConfigT], ModuleT]) -> ModuleT:
@wraps(func)
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
name = func.__name__
model = func(self, config)
model = model.to(self.device, dtype=self.dtype)
@ -117,12 +118,13 @@ def register_model():
CallbackRegistry = dict[str, Callback[Any]]
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
CallbackConfigT = TypeVar("CallbackConfigT", bound=CallbackConfig)
def register_callback():
def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
def decorator(func: Callable[[Any, CallbackConfigT], CallbackT]) -> CallbackT:
@wraps(func)
def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
def wrapper(self: "Trainer[BaseConfig, Any]", config: CallbackConfigT) -> CallbackT:
name = func.__name__
callback = func(self, config)
self.callbacks[name] = callback

View file

@ -1,5 +1,6 @@
[mock_model]
requires_grad = true
use_activation = true
[clock]
verbose = false

View file

@ -27,8 +27,12 @@ class MockBatch:
targets: torch.Tensor
class MockModelConfig(ModelConfig):
use_activation: bool
class MockConfig(BaseConfig):
mock_model: ModelConfig
mock_model: MockModelConfig
class MockModel(fl.Chain):
@ -39,6 +43,10 @@ class MockModel(fl.Chain):
fl.Linear(10, 10),
)
def add_activation(self) -> None:
self.insert(1, fl.SiLU())
self.insert(3, fl.SiLU())
class MockTrainer(Trainer[MockConfig, MockBatch]):
step_counter: int = 0
@ -57,8 +65,11 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
)
@register_model()
def mock_model(self, config: ModelConfig) -> MockModel:
return MockModel()
def mock_model(self, config: MockModelConfig) -> MockModel:
model = MockModel()
if config.use_activation:
model.add_activation()
return model
def compute_loss(self, batch: MockBatch) -> Tensor:
self.step_counter += 1