mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Enforce correct subtype for the config param in both decorators
Also add a custom ModelConfig for the MockTrainer test.
Update src/refiners/training_utils/config.py
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
0caa72a082
commit
3488273f50
|
@ -150,6 +150,7 @@ class ModelConfig(BaseModel):
|
||||||
# If None, then requires_grad will NOT be changed when loading the model
|
# If None, then requires_grad will NOT be changed when loading the model
|
||||||
# this can be useful if you want to train only a part of the model
|
# this can be useful if you want to train only a part of the model
|
||||||
requires_grad: bool | None = None
|
requires_grad: bool | None = None
|
||||||
|
# Optional, per-model optimizer parameters
|
||||||
learning_rate: float | None = None
|
learning_rate: float | None = None
|
||||||
betas: tuple[float, float] | None = None
|
betas: tuple[float, float] | None = None
|
||||||
eps: float | None = None
|
eps: float | None = None
|
||||||
|
|
|
@ -92,12 +92,13 @@ class ModelItem:
|
||||||
|
|
||||||
ModelRegistry = dict[str, ModelItem]
|
ModelRegistry = dict[str, ModelItem]
|
||||||
ModuleT = TypeVar("ModuleT", bound=fl.Module)
|
ModuleT = TypeVar("ModuleT", bound=fl.Module)
|
||||||
|
ModelConfigT = TypeVar("ModelConfigT", bound=ModelConfig)
|
||||||
|
|
||||||
|
|
||||||
def register_model():
|
def register_model():
|
||||||
def decorator(func: Callable[[Any, ModelConfig], ModuleT]) -> ModuleT:
|
def decorator(func: Callable[[Any, ModelConfigT], ModuleT]) -> ModuleT:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfig) -> fl.Module:
|
def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
|
||||||
name = func.__name__
|
name = func.__name__
|
||||||
model = func(self, config)
|
model = func(self, config)
|
||||||
model = model.to(self.device, dtype=self.dtype)
|
model = model.to(self.device, dtype=self.dtype)
|
||||||
|
@ -117,12 +118,13 @@ def register_model():
|
||||||
|
|
||||||
CallbackRegistry = dict[str, Callback[Any]]
|
CallbackRegistry = dict[str, Callback[Any]]
|
||||||
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
|
CallbackT = TypeVar("CallbackT", bound=Callback[Any])
|
||||||
|
CallbackConfigT = TypeVar("CallbackConfigT", bound=CallbackConfig)
|
||||||
|
|
||||||
|
|
||||||
def register_callback():
|
def register_callback():
|
||||||
def decorator(func: Callable[[Any, Any], CallbackT]) -> CallbackT:
|
def decorator(func: Callable[[Any, CallbackConfigT], CallbackT]) -> CallbackT:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(self: "Trainer[BaseConfig, Any]", config: Any) -> CallbackT:
|
def wrapper(self: "Trainer[BaseConfig, Any]", config: CallbackConfigT) -> CallbackT:
|
||||||
name = func.__name__
|
name = func.__name__
|
||||||
callback = func(self, config)
|
callback = func(self, config)
|
||||||
self.callbacks[name] = callback
|
self.callbacks[name] = callback
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
[mock_model]
|
[mock_model]
|
||||||
requires_grad = true
|
requires_grad = true
|
||||||
|
use_activation = true
|
||||||
|
|
||||||
[clock]
|
[clock]
|
||||||
verbose = false
|
verbose = false
|
||||||
|
|
|
@ -27,8 +27,12 @@ class MockBatch:
|
||||||
targets: torch.Tensor
|
targets: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class MockModelConfig(ModelConfig):
|
||||||
|
use_activation: bool
|
||||||
|
|
||||||
|
|
||||||
class MockConfig(BaseConfig):
|
class MockConfig(BaseConfig):
|
||||||
mock_model: ModelConfig
|
mock_model: MockModelConfig
|
||||||
|
|
||||||
|
|
||||||
class MockModel(fl.Chain):
|
class MockModel(fl.Chain):
|
||||||
|
@ -39,6 +43,10 @@ class MockModel(fl.Chain):
|
||||||
fl.Linear(10, 10),
|
fl.Linear(10, 10),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_activation(self) -> None:
|
||||||
|
self.insert(1, fl.SiLU())
|
||||||
|
self.insert(3, fl.SiLU())
|
||||||
|
|
||||||
|
|
||||||
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
step_counter: int = 0
|
step_counter: int = 0
|
||||||
|
@ -57,8 +65,11 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
|
||||||
)
|
)
|
||||||
|
|
||||||
@register_model()
|
@register_model()
|
||||||
def mock_model(self, config: ModelConfig) -> MockModel:
|
def mock_model(self, config: MockModelConfig) -> MockModel:
|
||||||
return MockModel()
|
model = MockModel()
|
||||||
|
if config.use_activation:
|
||||||
|
model.add_activation()
|
||||||
|
return model
|
||||||
|
|
||||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||||
self.step_counter += 1
|
self.step_counter += 1
|
||||||
|
|
Loading…
Reference in a new issue