mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
Enforce correct subtype for the config param in both decorators
Also add a custom ModelConfig for the MockTrainer test Update src/refiners/training_utils/config.py Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
0caa72a082
commit
3488273f50
|
@ -150,6 +150,7 @@ class ModelConfig(BaseModel):
|
|||
# If None, then requires_grad will NOT be changed when loading the model
|
||||
# this can be useful if you want to train only a part of the model
|
||||
requires_grad: bool | None = None
|
||||
# Optional, per-model optimizer parameters
|
||||
learning_rate: float | None = None
|
||||
betas: tuple[float, float] | None = None
|
||||
eps: float | None = None
|
||||
|
|
|
@ -92,12 +92,13 @@ class ModelItem:
|
|||
|
||||
ModelRegistry = dict[str, ModelItem]
|
||||
ModuleT = TypeVar("ModuleT", bound=fl.Module)
|
||||
ModelConfigT = TypeVar("ModelConfigT", bound=ModelConfig)
|
||||
|
||||
|
||||
def register_model():
|
||||
def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
|
||||
def decorator(func: Callable[[Any, ModelConfigT], ModuleT]) -> ModuleT:
|
||||
@wraps(func)
|
||||
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
|
||||
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
|
||||
name = func.__name__
|
||||
model = func(self, config)
|
||||
model = model.to(self.device, dtype=self.dtype)
|
||||
|
@ -117,12 +118,13 @@ def register_model():
|
|||
|
||||
CallbackRegistry = dict[str, Callback[Any]]
|
||||
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
|
||||
CallbackConfigT = TypeVar("CallbackConfigT", bound=CallbackConfig)
|
||||
|
||||
|
||||
def register_callback():
|
||||
def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
|
||||
def decorator(func: Callable[[Any, CallbackConfigT], CallbackT]) -> CallbackT:
|
||||
@wraps(func)
|
||||
def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
|
||||
def wrapper(self: "Trainer[BaseConfig, Any]", config: CallbackConfigT) -> CallbackT:
|
||||
name = func.__name__
|
||||
callback = func(self, config)
|
||||
self.callbacks[name] = callback
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
[mock_model]
|
||||
requires_grad = true
|
||||
use_activation = true
|
||||
|
||||
[clock]
|
||||
verbose = false
|
||||
|
|
|
@ -27,8 +27,12 @@ class MockBatch:
|
|||
targets: torch.Tensor
|
||||
|
||||
|
||||
class MockModelConfig(ModelConfig):
|
||||
use_activation: bool
|
||||
|
||||
|
||||
class MockConfig(BaseConfig):
|
||||
mock_model: ModelConfig
|
||||
mock_model: MockModelConfig
|
||||
|
||||
|
||||
class MockModel(fl.Chain):
|
||||
|
@ -39,6 +43,10 @@ class MockModel(fl.Chain):
|
|||
fl.Linear(10, 10),
|
||||
)
|
||||
|
||||
def add_activation(self) -> None:
|
||||
self.insert(1, fl.SiLU())
|
||||
self.insert(3, fl.SiLU())
|
||||
|
||||
|
||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||
step_counter: int = 0
|
||||
|
@ -57,8 +65,11 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
|||
)
|
||||
|
||||
@register_model()
|
||||
def mock_model(self, config: ModelConfig) -> MockModel:
|
||||
return MockModel()
|
||||
def mock_model(self, config: MockModelConfig) -> MockModel:
|
||||
model = MockModel()
|
||||
if config.use_activation:
|
||||
model.add_activation()
|
||||
return model
|
||||
|
||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||
self.step_counter += 1
|
||||
|
|
Loading…
Reference in a new issue