lr, betas, eps, weight_decay at model level

Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
Pierre Colle 2024-02-01 09:26:59 +01:00 committed by Cédric Deltheil
parent 9aefc9896c
commit 25bfa78907
4 changed files with 115 additions and 11 deletions


@@ -1,14 +1,18 @@
 from enum import Enum
 from logging import warn
 from pathlib import Path
-from typing import Any, Callable, Iterable, Literal, Type, TypeVar
+from typing import Any, Callable, Literal, Type, TypeVar

 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, validator
 from torch.nn import Parameter
 from torch.optim import SGD, Adam, AdamW, Optimizer
+try:
+    from torch.optim.optimizer import params_t as ParamsT  # PyTorch 2.1. TODO: remove "soon"
+except ImportError as e:
+    from torch.optim.optimizer import ParamsT
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version

 import refiners.fluxion.layers as fl
@@ -119,17 +123,17 @@ class OptimizerConfig(BaseModel):
     eps: float = 1e-8
     weight_decay: float = 0.0

-    def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
+    def get(self, params: ParamsT) -> Optimizer:
         match self.optimizer:
             case Optimizers.SGD:
                 return SGD(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     weight_decay=self.weight_decay,
                 )
             case Optimizers.Adam:
                 return Adam(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -137,7 +141,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.AdamW:
                 return AdamW(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -145,7 +149,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.AdamW8bit:
                 return AdamW8bit(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -153,7 +157,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.Lion8bit:
                 return Lion8bit(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     weight_decay=self.weight_decay,  # type: ignore
@@ -163,7 +167,7 @@ class OptimizerConfig(BaseModel):
                     warn("Prodigy learning rate is not 1.0, this might cause instability.")
                 return Prodigy(
                     lr=self.learning_rate,
-                    params=model_parameters,
+                    params=params,
                     betas=self.betas,
                     weight_decay=self.weight_decay,  # type: ignore
                     safeguard_warmup=True,
@@ -174,7 +178,10 @@ class OptimizerConfig(BaseModel):
 class ModelConfig(BaseModel):
     checkpoint: Path | None = None
     train: bool = True
-    learning_rate: float | None = None  # TODO: Implement this
+    learning_rate: float | None = None
+    betas: tuple[float, float] | None = None
+    eps: float | None = None
+    weight_decay: float | None = None


 class GyroDropoutConfig(BaseModel):
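The `get()` signature change above, from `Iterable[Parameter]` to `ParamsT`, is what makes per-model options possible: torch optimizers accept either a plain iterable of parameter tensors or an iterable of dicts ("param groups"), where each dict carries its own `params` plus optional overrides such as `lr` (see https://pytorch.org/docs/stable/optim.html#per-parameter-options). Below is a minimal sketch of the two call styles; the `Linear` modules, `AdamW` choice, and values are illustrative only and not part of this commit:

```python
from torch.nn import Linear
from torch.optim import AdamW

# Toy modules standing in for two trainable models.
backbone, head = Linear(16, 16), Linear(16, 4)

# Style 1: a plain iterable of parameters, all sharing the optimizer defaults.
opt_a = AdamW(backbone.parameters(), lr=1e-4)

# Style 2: param groups, each dict optionally overriding the defaults.
opt_b = AdamW(
    [
        {"params": backbone.parameters(), "lr": 1e-5},  # per-group override
        {"params": head.parameters()},  # no override: falls back to the constructor default lr=1e-4
    ],
    lr=1e-4,
)
```

Options omitted from a group fall back to the defaults passed to the optimizer constructor, which is how the new model-level fields on `ModelConfig` coexist with the global `[optimizer]` settings.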


@@ -329,6 +329,32 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         """Returns a list of learnable parameters in all models"""
         return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]

+    @property
+    def optimizer_parameters(self) -> list[dict[str, Any]]:
+        """
+        Returns a list of `dict`-s containing the params and optimizer options for each model.
+        See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
+        """
+        params: list[dict[str, Any]] = []
+        for model_name, model in self.models.items():
+            model_params = [param for param in model.parameters() if param.requires_grad]
+            model_config = self.config.models[model_name]
+            model_optim_conf: dict[str, Any] = {}
+
+            if model_config.learning_rate is not None:
+                model_optim_conf["lr"] = model_config.learning_rate
+            if model_config.weight_decay is not None:
+                model_optim_conf["weight_decay"] = model_config.weight_decay
+            if model_config.betas is not None:
+                model_optim_conf["betas"] = model_config.betas
+            if model_config.eps is not None:
+                model_optim_conf["eps"] = model_config.eps
+
+            for param in model_params:
+                params.append({"params": param, **model_optim_conf})
+
+        return params
+
     @property
     def learnable_parameter_count(self) -> int:
         """Returns the number of learnable parameters in all models"""
@@ -353,7 +379,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def optimizer(self) -> Optimizer:
         formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
         logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
-        optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
+        optimizer = self.config.optimizer.get(params=self.optimizer_parameters)
+
         return optimizer

     @cached_property
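For reference, here is a standalone sketch of the parameter-group layout that `optimizer_parameters` builds and `OptimizerConfig.get` now consumes: one dict per trainable parameter, with any model-level overrides merged in and everything else falling back to the optimizer-level defaults. The `Sequential`/`SGD` stand-ins and values are illustrative (they mirror the mock config and test added later in this commit), not the Trainer code itself:

```python
from torch.nn import Linear, Sequential
from torch.optim import SGD

# Stand-ins for two registered models; the real Trainer iterates self.models.
model1 = Sequential(Linear(8, 8), Linear(8, 8), Linear(8, 8))
model2 = Sequential(Linear(8, 8), Linear(8, 8), Linear(8, 8))

# One dict per trainable parameter, with the per-model override merged in,
# mirroring what optimizer_parameters produces for a model-level learning_rate.
param_groups = [{"params": p, "lr": 1e-5} for p in model1.parameters() if p.requires_grad]
param_groups += [{"params": p} for p in model2.parameters() if p.requires_grad]

optimizer = SGD(param_groups, lr=1.0, momentum=0.9)  # optimizer-level defaults
assert len(optimizer.param_groups) == 12  # (3 layers * 2 tensors) * 2 models
assert optimizer.param_groups[0]["lr"] == 1e-5  # model1 override
assert optimizer.param_groups[-1]["lr"] == 1.0  # model2 falls back to the default
```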


@@ -0,0 +1,34 @@
+[models.mock_model1]
+train = true
+learning_rate = 1e-5
+
+[models.mock_model2]
+train = true
+
+[training]
+duration = "100:epoch"
+seed = 0
+batch_size = 4
+gradient_accumulation = "4:step"
+clip_grad_norm = 1.0
+evaluation_interval = "5:epoch"
+evaluation_seed = 1
+
+[optimizer]
+optimizer = "SGD"
+learning_rate = 1
+momentum = 0.9
+
+[scheduler]
+scheduler_type = "ConstantLR"
+update_interval = "1:step"
+
+[dropout]
+dropout = 0.0
+
+[checkpointing]
+save_interval = "10:epoch"
+
+[wandb]
+mode = "disabled"
+project = "mock_project"


@@ -214,3 +214,39 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
     optimizer = warmup_scheduler.optimizer
     for group in optimizer.param_groups:
         assert group["lr"] == 0.1
+
+
+class MockTrainerWith2Models(MockTrainer):
+    @cached_property
+    def mock_model1(self) -> MockModel:
+        return MockModel()
+
+    @cached_property
+    def mock_model2(self) -> MockModel:
+        return MockModel()
+
+    def load_models(self) -> dict[str, fl.Module]:
+        return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
+
+    def compute_loss(self, batch: MockBatch) -> Tensor:
+        self.step_counter += 1
+        inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
+        outputs = self.mock_model2(self.mock_model1(inputs))
+        return norm(outputs - targets)
+
+
+@pytest.fixture
+def mock_config_2_models(test_device: torch.device) -> MockConfig:
+    return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
+
+
+@pytest.fixture
+def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2Models:
+    return MockTrainerWith2Models(config=mock_config_2_models)
+
+
+def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
+    assert (
+        len(mock_trainer_2_models.optimizer.param_groups) == 12
+    )  # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
+    assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5