lr, betas, eps, weight_decay at model level

Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
Author: Pierre Colle, 2024-02-01 09:26:59 +01:00 (committed by Cédric Deltheil)
Parent: 9aefc9896c
Commit: 25bfa78907
4 changed files with 115 additions and 11 deletions
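
As a hedged illustration of what the commit enables, the new per-model fields let a model entry carry its own optimizer hyper-parameters instead of relying only on the global optimizer settings. Field names below match the ModelConfig diff further down; the import path is an assumption.

from refiners.training_utils.config import ModelConfig  # assumed path

# Per-model overrides; any field left as None falls back to the [optimizer] section.
model_cfg = ModelConfig(
    train=True,
    learning_rate=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-2,
)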


@@ -1,14 +1,18 @@
 from enum import Enum
 from logging import warn
 from pathlib import Path
-from typing import Any, Callable, Iterable, Literal, Type, TypeVar
+from typing import Any, Callable, Literal, Type, TypeVar

 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, validator
-from torch.nn import Parameter
 from torch.optim import SGD, Adam, AdamW, Optimizer
+
+try:
+    from torch.optim.optimizer import params_t as ParamsT  # PyTorch 2.1. TODO: remove "soon"
+except ImportError as e:
+    from torch.optim.optimizer import ParamsT
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version

 import refiners.fluxion.layers as fl
@@ -119,17 +123,17 @@ class OptimizerConfig(BaseModel):
     eps: float = 1e-8
     weight_decay: float = 0.0

-    def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
+    def get(self, params: ParamsT) -> Optimizer:
         match self.optimizer:
             case Optimizers.SGD:
                 return SGD(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     weight_decay=self.weight_decay,
                 )
             case Optimizers.Adam:
                 return Adam(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -137,7 +141,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.AdamW:
                 return AdamW(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -145,7 +149,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.AdamW8bit:
                 return AdamW8bit(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -153,7 +157,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.Lion8bit:
                 return Lion8bit(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     weight_decay=self.weight_decay,  # type: ignore
@@ -163,7 +167,7 @@ class OptimizerConfig(BaseModel):
                     warn("Prodigy learning rate is not 1.0, this might cause instability.")
                 return Prodigy(
                     lr=self.learning_rate,
-                    params=model_parameters,
+                    params=params,
                     betas=self.betas,
                     weight_decay=self.weight_decay,  # type: ignore
                     safeguard_warmup=True,
@@ -174,7 +178,10 @@ class OptimizerConfig(BaseModel):
 class ModelConfig(BaseModel):
     checkpoint: Path | None = None
     train: bool = True
-    learning_rate: float | None = None  # TODO: Implement this
+    learning_rate: float | None = None
+    betas: tuple[float, float] | None = None
+    eps: float | None = None
+    weight_decay: float | None = None


 class GyroDropoutConfig(BaseModel):
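
Since ParamsT covers both a plain iterable of tensors and a list of per-parameter-group dicts, the new get(params=...) signature accepts either form. A hedged usage sketch follows; the import path is an assumption, and the class and field names are taken from the diff above.

from torch.nn import Linear

from refiners.training_utils.config import OptimizerConfig, Optimizers  # assumed path

model = Linear(4, 4)
config = OptimizerConfig(optimizer=Optimizers.SGD, learning_rate=1.0, betas=(0.9, 0.999))

# ParamsT form 1: a flat iterable of parameters, all sharing the optimizer-level defaults.
optimizer = config.get(params=model.parameters())

# ParamsT form 2: per-parameter groups, which is what the Trainer now builds (see below).
optimizer = config.get(params=[{"params": [p], "lr": 1e-5} for p in model.parameters()])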


@@ -329,6 +329,32 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         """Returns a list of learnable parameters in all models"""
         return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]

+    @property
+    def optimizer_parameters(self) -> list[dict[str, Any]]:
+        """
+        Returns a list of `dict`s containing the params and optimizer options for each model.
+        See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details.
+        """
+        params: list[dict[str, Any]] = []
+        for model_name, model in self.models.items():
+            model_params = [param for param in model.parameters() if param.requires_grad]
+            model_config = self.config.models[model_name]
+            model_optim_conf: dict[str, Any] = {}
+            if model_config.learning_rate is not None:
+                model_optim_conf["lr"] = model_config.learning_rate
+            if model_config.weight_decay is not None:
+                model_optim_conf["weight_decay"] = model_config.weight_decay
+            if model_config.betas is not None:
+                model_optim_conf["betas"] = model_config.betas
+            if model_config.eps is not None:
+                model_optim_conf["eps"] = model_config.eps
+            for param in model_params:
+                params.append({"params": param, **model_optim_conf})
+        return params
+
     @property
     def learnable_parameter_count(self) -> int:
         """Returns the number of learnable parameters in all models"""
@@ -353,7 +379,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def optimizer(self) -> Optimizer:
         formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
         logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
-        optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
+        optimizer = self.config.optimizer.get(params=self.optimizer_parameters)
+
         return optimizer

     @cached_property
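
The net effect of optimizer_parameters is one parameter group per trainable tensor, with that model's overrides spliced into each of its groups; groups without an explicit "lr" (or other key) inherit the optimizer-level defaults. Below is a standalone, plain-torch rendition of that behavior, with hypothetical model names and no refiners imports.

from typing import Any

import torch
from torch.nn import Linear

# Two toy models, one with a per-model learning-rate override (hypothetical names).
models: dict[str, torch.nn.Module] = {"model_a": Linear(4, 4), "model_b": Linear(4, 4)}
overrides: dict[str, dict[str, Any]] = {"model_a": {"lr": 1e-5}, "model_b": {}}

# One group per trainable tensor, mirroring the property above.
param_groups: list[dict[str, Any]] = []
for name, model in models.items():
    for param in model.parameters():
        if param.requires_grad:
            param_groups.append({"params": param, **overrides[name]})

optimizer = torch.optim.SGD(param_groups, lr=1.0)
assert len(optimizer.param_groups) == 4         # (weight + bias) x 2 models
assert optimizer.param_groups[0]["lr"] == 1e-5  # model_a carries its override
assert optimizer.param_groups[-1]["lr"] == 1.0  # model_b falls back to the default lr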


@@ -0,0 +1,34 @@
+[models.mock_model1]
+train = true
+learning_rate = 1e-5
+
+[models.mock_model2]
+train = true
+
+[training]
+duration = "100:epoch"
+seed = 0
+batch_size = 4
+gradient_accumulation = "4:step"
+clip_grad_norm = 1.0
+evaluation_interval = "5:epoch"
+evaluation_seed = 1
+
+[optimizer]
+optimizer = "SGD"
+learning_rate = 1
+momentum = 0.9
+
+[scheduler]
+scheduler_type = "ConstantLR"
+update_interval = "1:step"
+
+[dropout]
+dropout = 0.0
+
+[checkpointing]
+save_interval = "10:epoch"
+
+[wandb]
+mode = "disabled"
+project = "mock_project"


@@ -214,3 +214,39 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
     optimizer = warmup_scheduler.optimizer
     for group in optimizer.param_groups:
         assert group["lr"] == 0.1
+
+
+class MockTrainerWith2Models(MockTrainer):
+    @cached_property
+    def mock_model1(self) -> MockModel:
+        return MockModel()
+
+    @cached_property
+    def mock_model2(self) -> MockModel:
+        return MockModel()
+
+    def load_models(self) -> dict[str, fl.Module]:
+        return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
+
+    def compute_loss(self, batch: MockBatch) -> Tensor:
+        self.step_counter += 1
+        inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
+        outputs = self.mock_model2(self.mock_model1(inputs))
+        return norm(outputs - targets)
+
+
+@pytest.fixture
+def mock_config_2_models(test_device: torch.device) -> MockConfig:
+    return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
+
+
+@pytest.fixture
+def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2Models:
+    return MockTrainerWith2Models(config=mock_config_2_models)
+
+
+def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
+    assert (
+        len(mock_trainer_2_models.optimizer.param_groups) == 12
+    )  # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
+    assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
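
A hedged follow-up check one could add (not part of this commit), assuming MockModel is the three-Linear-layer stack implied by the inline comment and that parameter group order follows the load_models insertion order:

# Hypothetical extra test: shows where 12 comes from and that mock_model2's groups inherit
# the global [optimizer] learning_rate (1) because they carry no per-model override.
def test_optimizer_parameter_fallback(mock_trainer_2_models: MockTrainerWith2Models) -> None:
    groups = mock_trainer_2_models.optimizer.param_groups
    assert len(groups) == 3 * 2 * 2  # 3 linear layers x (weight + bias) x 2 models
    assert all(group["lr"] == 1e-5 for group in groups[:6])  # mock_model1 override from the TOML
    assert all(group["lr"] == 1 for group in groups[6:])  # mock_model2 falls back to SGD's lr=1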