mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
lr, betas, eps, weight_decay at model level
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
9aefc9896c
commit
25bfa78907
|
@ -1,14 +1,18 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import warn
|
from logging import warn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
from typing import Any, Callable, Literal, Type, TypeVar
|
||||||
|
|
||||||
import tomli
|
import tomli
|
||||||
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||||
from prodigyopt import Prodigy # type: ignore
|
from prodigyopt import Prodigy # type: ignore
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
from torch.nn import Parameter
|
|
||||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.optim.optimizer import params_t as ParamsT # PyTorch 2.1. TODO: remove "soon"
|
||||||
|
except ImportError as e:
|
||||||
|
from torch.optim.optimizer import ParamsT
|
||||||
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
@ -119,17 +123,17 @@ class OptimizerConfig(BaseModel):
|
||||||
eps: float = 1e-8
|
eps: float = 1e-8
|
||||||
weight_decay: float = 0.0
|
weight_decay: float = 0.0
|
||||||
|
|
||||||
def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
|
def get(self, params: ParamsT) -> Optimizer:
|
||||||
match self.optimizer:
|
match self.optimizer:
|
||||||
case Optimizers.SGD:
|
case Optimizers.SGD:
|
||||||
return SGD(
|
return SGD(
|
||||||
params=model_parameters,
|
params=params,
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
weight_decay=self.weight_decay,
|
weight_decay=self.weight_decay,
|
||||||
)
|
)
|
||||||
case Optimizers.Adam:
|
case Optimizers.Adam:
|
||||||
return Adam(
|
return Adam(
|
||||||
params=model_parameters,
|
params=params,
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
betas=self.betas,
|
betas=self.betas,
|
||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
|
@ -137,7 +141,7 @@ class OptimizerConfig(BaseModel):
|
||||||
)
|
)
|
||||||
case Optimizers.AdamW:
|
case Optimizers.AdamW:
|
||||||
return AdamW(
|
return AdamW(
|
||||||
params=model_parameters,
|
params=params,
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
betas=self.betas,
|
betas=self.betas,
|
||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
|
@ -145,7 +149,7 @@ class OptimizerConfig(BaseModel):
|
||||||
)
|
)
|
||||||
case Optimizers.AdamW8bit:
|
case Optimizers.AdamW8bit:
|
||||||
return AdamW8bit(
|
return AdamW8bit(
|
||||||
params=model_parameters,
|
params=params,
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
betas=self.betas,
|
betas=self.betas,
|
||||||
eps=self.eps,
|
eps=self.eps,
|
||||||
|
@ -153,7 +157,7 @@ class OptimizerConfig(BaseModel):
|
||||||
)
|
)
|
||||||
case Optimizers.Lion8bit:
|
case Optimizers.Lion8bit:
|
||||||
return Lion8bit(
|
return Lion8bit(
|
||||||
params=model_parameters,
|
params=params,
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
betas=self.betas,
|
betas=self.betas,
|
||||||
weight_decay=self.weight_decay, # type: ignore
|
weight_decay=self.weight_decay, # type: ignore
|
||||||
|
@ -163,7 +167,7 @@ class OptimizerConfig(BaseModel):
|
||||||
warn("Prodigy learning rate is not 1.0, this might cause instability.")
|
warn("Prodigy learning rate is not 1.0, this might cause instability.")
|
||||||
return Prodigy(
|
return Prodigy(
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
params=model_parameters,
|
params=params,
|
||||||
betas=self.betas,
|
betas=self.betas,
|
||||||
weight_decay=self.weight_decay, # type: ignore
|
weight_decay=self.weight_decay, # type: ignore
|
||||||
safeguard_warmup=True,
|
safeguard_warmup=True,
|
||||||
|
@ -174,7 +178,10 @@ class OptimizerConfig(BaseModel):
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
checkpoint: Path | None = None
|
checkpoint: Path | None = None
|
||||||
train: bool = True
|
train: bool = True
|
||||||
learning_rate: float | None = None # TODO: Implement this
|
learning_rate: float | None = None
|
||||||
|
betas: tuple[float, float] | None = None
|
||||||
|
eps: float | None = None
|
||||||
|
weight_decay: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class GyroDropoutConfig(BaseModel):
|
class GyroDropoutConfig(BaseModel):
|
||||||
|
|
|
@ -329,6 +329,32 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
"""Returns a list of learnable parameters in all models"""
|
"""Returns a list of learnable parameters in all models"""
|
||||||
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
|
return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def optimizer_parameters(self) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Returns a list of `dict`-s containing the params and optimizer options for each model.
|
||||||
|
See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details
|
||||||
|
"""
|
||||||
|
params: list[dict[str, Any]] = []
|
||||||
|
for model_name, model in self.models.items():
|
||||||
|
model_params = [param for param in model.parameters() if param.requires_grad]
|
||||||
|
model_config = self.config.models[model_name]
|
||||||
|
model_optim_conf: dict[str, Any] = {}
|
||||||
|
|
||||||
|
if model_config.learning_rate is not None:
|
||||||
|
model_optim_conf["lr"] = model_config.learning_rate
|
||||||
|
if model_config.weight_decay is not None:
|
||||||
|
model_optim_conf["weight_decay"] = model_config.learning_rate
|
||||||
|
if model_config.betas is not None:
|
||||||
|
model_optim_conf["betas"] = model_config.learning_rate
|
||||||
|
if model_config.eps is not None:
|
||||||
|
model_optim_conf["eps"] = model_config.learning_rate
|
||||||
|
|
||||||
|
for param in model_params:
|
||||||
|
params.append({"params": param, **model_optim_conf})
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def learnable_parameter_count(self) -> int:
|
def learnable_parameter_count(self) -> int:
|
||||||
"""Returns the number of learnable parameters in all models"""
|
"""Returns the number of learnable parameters in all models"""
|
||||||
|
@ -353,7 +379,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
def optimizer(self) -> Optimizer:
|
def optimizer(self) -> Optimizer:
|
||||||
formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
|
formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
|
||||||
logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
|
logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
|
||||||
optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
|
|
||||||
|
optimizer = self.config.optimizer.get(params=self.optimizer_parameters)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
|
|
34
tests/training_utils/mock_config_2_models.toml
Normal file
34
tests/training_utils/mock_config_2_models.toml
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
[models.mock_model1]
|
||||||
|
train = true
|
||||||
|
learning_rate = 1e-5
|
||||||
|
|
||||||
|
[models.mock_model2]
|
||||||
|
train = true
|
||||||
|
|
||||||
|
[training]
|
||||||
|
duration = "100:epoch"
|
||||||
|
seed = 0
|
||||||
|
batch_size = 4
|
||||||
|
gradient_accumulation = "4:step"
|
||||||
|
clip_grad_norm = 1.0
|
||||||
|
evaluation_interval = "5:epoch"
|
||||||
|
evaluation_seed = 1
|
||||||
|
|
||||||
|
[optimizer]
|
||||||
|
optimizer = "SGD"
|
||||||
|
learning_rate = 1
|
||||||
|
momentum = 0.9
|
||||||
|
|
||||||
|
[scheduler]
|
||||||
|
scheduler_type = "ConstantLR"
|
||||||
|
update_interval = "1:step"
|
||||||
|
|
||||||
|
[dropout]
|
||||||
|
dropout = 0.0
|
||||||
|
|
||||||
|
[checkpointing]
|
||||||
|
save_interval = "10:epoch"
|
||||||
|
|
||||||
|
[wandb]
|
||||||
|
mode = "disabled"
|
||||||
|
project = "mock_project"
|
|
@ -214,3 +214,39 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
|
||||||
optimizer = warmup_scheduler.optimizer
|
optimizer = warmup_scheduler.optimizer
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
assert group["lr"] == 0.1
|
assert group["lr"] == 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class MockTrainerWith2Models(MockTrainer):
|
||||||
|
@cached_property
|
||||||
|
def mock_model1(self) -> MockModel:
|
||||||
|
return MockModel()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def mock_model2(self) -> MockModel:
|
||||||
|
return MockModel()
|
||||||
|
|
||||||
|
def load_models(self) -> dict[str, fl.Module]:
|
||||||
|
return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
|
||||||
|
|
||||||
|
def compute_loss(self, batch: MockBatch) -> Tensor:
|
||||||
|
self.step_counter += 1
|
||||||
|
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
|
||||||
|
outputs = self.mock_model2(self.mock_model1(inputs))
|
||||||
|
return norm(outputs - targets)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_2_models(test_device: torch.device) -> MockConfig:
|
||||||
|
return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2Models:
|
||||||
|
return MockTrainerWith2Models(config=mock_config_2_models)
|
||||||
|
|
||||||
|
|
||||||
|
def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
|
||||||
|
assert (
|
||||||
|
len(mock_trainer_2_models.optimizer.param_groups) == 12
|
||||||
|
) # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
|
||||||
|
assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
|
||||||
|
|
Loading…
Reference in a new issue