mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
lr, betas, eps, weight_decay at model level
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
9aefc9896c
commit
25bfa78907
|
@ -1,14 +1,18 @@
|
|||
from enum import Enum
|
||||
from logging import warn
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
||||
from typing import Any, Callable, Literal, Type, TypeVar
|
||||
|
||||
import tomli
|
||||
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||
from prodigyopt import Prodigy # type: ignore
|
||||
from pydantic import BaseModel, validator
|
||||
from torch.nn import Parameter
|
||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
|
||||
try:
|
||||
from torch.optim.optimizer import params_t as ParamsT # PyTorch 2.1. TODO: remove "soon"
|
||||
except ImportError as e:
|
||||
from torch.optim.optimizer import ParamsT
|
||||
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
|
@ -119,17 +123,17 @@ class OptimizerConfig(BaseModel):
|
|||
eps: float = 1e-8
|
||||
weight_decay: float = 0.0
|
||||
|
||||
def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
|
||||
def get(self, params: ParamsT) -> Optimizer:
|
||||
match self.optimizer:
|
||||
case Optimizers.SGD:
|
||||
return SGD(
|
||||
params=model_parameters,
|
||||
params=params,
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
case Optimizers.Adam:
|
||||
return Adam(
|
||||
params=model_parameters,
|
||||
params=params,
|
||||
lr=self.learning_rate,
|
||||
betas=self.betas,
|
||||
eps=self.eps,
|
||||
|
@ -137,7 +141,7 @@ class OptimizerConfig(BaseModel):
|
|||
)
|
||||
case Optimizers.AdamW:
|
||||
return AdamW(
|
||||
params=model_parameters,
|
||||
params=params,
|
||||
lr=self.learning_rate,
|
||||
betas=self.betas,
|
||||
eps=self.eps,
|
||||
|
@ -145,7 +149,7 @@ class OptimizerConfig(BaseModel):
|
|||
)
|
||||
case Optimizers.AdamW8bit:
|
||||
return AdamW8bit(
|
||||
params=model_parameters,
|
||||
params=params,
|
||||
lr=self.learning_rate,
|
||||
betas=self.betas,
|
||||
eps=self.eps,
|
||||
|
@ -153,7 +157,7 @@ class OptimizerConfig(BaseModel):
|
|||
)
|
||||
case Optimizers.Lion8bit:
|
||||
return Lion8bit(
|
||||
params=model_parameters,
|
||||
params=params,
|
||||
lr=self.learning_rate,
|
||||
betas=self.betas,
|
||||
weight_decay=self.weight_decay, # type: ignore
|
||||
|
@ -163,7 +167,7 @@ class OptimizerConfig(BaseModel):
|
|||
warn("Prodigy learning rate is not 1.0, this might cause instability.")
|
||||
return Prodigy(
|
||||
lr=self.learning_rate,
|
||||
params=model_parameters,
|
||||
params=params,
|
||||
betas=self.betas,
|
||||
weight_decay=self.weight_decay, # type: ignore
|
||||
safeguard_warmup=True,
|
||||
|
@ -174,7 +178,10 @@ class OptimizerConfig(BaseModel):
|
|||
class ModelConfig(BaseModel):
|
||||
checkpoint: Path | None = None
|
||||
train: bool = True
|
||||
learning_rate: float | None = None # TODO: Implement this
|
||||
learning_rate: float | None = None
|
||||
betas: tuple[float, float] | None = None
|
||||
eps: float | None = None
|
||||
weight_decay: float | None = None
|
||||
|
||||
|
||||
class GyroDropoutConfig(BaseModel):
|
||||
|
|
|
@ -329,6 +329,32 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
"""Returns a list of learnable parameters in all models"""
|
||||
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
|
||||
|
||||
@property
|
||||
def optimizer_parameters(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Returns a list of `dict`-s containing the params and optimizer options for each model.
|
||||
See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
|
||||
"""
|
||||
params: list[dict[str, Any]] = []
|
||||
for model_name, model in self.models.items():
|
||||
model_params = [param for param in model.parameters() if param.requires_grad]
|
||||
model_config = self.config.models[model_name]
|
||||
model_optim_conf: dict[str, Any] = {}
|
||||
|
||||
if model_config.learning_rate is not None:
|
||||
model_optim_conf["lr"] = model_config.learning_rate
|
||||
if model_config.weight_decay is not None:
|
||||
model_optim_conf["weight_decay"] = model_config.learning_rate
|
||||
if model_config.betas is not None:
|
||||
model_optim_conf["betas"] = model_config.learning_rate
|
||||
if model_config.eps is not None:
|
||||
model_optim_conf["eps"] = model_config.learning_rate
|
||||
|
||||
for param in model_params:
|
||||
params.append({"params": param, **model_optim_conf})
|
||||
|
||||
return params
|
||||
|
||||
@property
|
||||
def learnable_parameter_count(self) -> int:
|
||||
"""Returns the number of learnable parameters in all models"""
|
||||
|
@ -353,7 +379,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
def optimizer(self) -> Optimizer:
|
||||
formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
|
||||
logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
|
||||
optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
|
||||
|
||||
optimizer = self.config.optimizer.get(params=self.optimizer_parameters)
|
||||
return optimizer
|
||||
|
||||
@cached_property
|
||||
|
|
34
tests/training_utils/mock_config_2_models.toml
Normal file
34
tests/training_utils/mock_config_2_models.toml
Normal file
|
@ -0,0 +1,34 @@
|
|||
[models.mock_model1]
|
||||
train = true
|
||||
learning_rate = 1e-5
|
||||
|
||||
[models.mock_model2]
|
||||
train = true
|
||||
|
||||
[training]
|
||||
duration = "100:epoch"
|
||||
seed = 0
|
||||
batch_size = 4
|
||||
gradient_accumulation = "4:step"
|
||||
clip_grad_norm = 1.0
|
||||
evaluation_interval = "5:epoch"
|
||||
evaluation_seed = 1
|
||||
|
||||
[optimizer]
|
||||
optimizer = "SGD"
|
||||
learning_rate = 1
|
||||
momentum = 0.9
|
||||
|
||||
[scheduler]
|
||||
scheduler_type = "ConstantLR"
|
||||
update_interval = "1:step"
|
||||
|
||||
[dropout]
|
||||
dropout = 0.0
|
||||
|
||||
[checkpointing]
|
||||
save_interval = "10:epoch"
|
||||
|
||||
[wandb]
|
||||
mode = "disabled"
|
||||
project = "mock_project"
|
|
@ -214,3 +214,39 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
|
|||
optimizer = warmup_scheduler.optimizer
|
||||
for group in optimizer.param_groups:
|
||||
assert group["lr"] == 0.1
|
||||
|
||||
|
||||
class MockTrainerWith2Models(MockTrainer):
|
||||
@cached_property
|
||||
def mock_model1(self) -> MockModel:
|
||||
return MockModel()
|
||||
|
||||
@cached_property
|
||||
def mock_model2(self) -> MockModel:
|
||||
return MockModel()
|
||||
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
|
||||
|
||||
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||
self.step_counter += 1
|
||||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||
outputs = self.mock_model2(self.mock_model1(inputs))
|
||||
return norm(outputs - targets)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_2_models(test_device: torch.device) -> MockConfig:
|
||||
return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2Models:
|
||||
return MockTrainerWith2Models(config=mock_config_2_models)
|
||||
|
||||
|
||||
def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
|
||||
assert (
|
||||
len(mock_trainer_2_models.optimizer.param_groups) == 12
|
||||
) # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
|
||||
assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
|
||||
|
|
Loading…
Reference in a new issue