diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 0901f5d..afaa641 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -1,14 +1,18 @@
 from enum import Enum
 from logging import warn
 from pathlib import Path
-from typing import Any, Callable, Iterable, Literal, Type, TypeVar
+from typing import Any, Callable, Literal, Type, TypeVar
 
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, validator
-from torch.nn import Parameter
 from torch.optim import SGD, Adam, AdamW, Optimizer
+
+try:
+    from torch.optim.optimizer import params_t as ParamsT  # PyTorch 2.1 name. TODO: remove once 2.1 support is dropped
+except ImportError:
+    from torch.optim.optimizer import ParamsT
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
 
 import refiners.fluxion.layers as fl
@@ -119,17 +123,17 @@ class OptimizerConfig(BaseModel):
     eps: float = 1e-8
     weight_decay: float = 0.0
 
-    def get(self, model_parameters: Iterable[Parameter]) -> Optimizer:
+    def get(self, params: ParamsT) -> Optimizer:
         match self.optimizer:
             case Optimizers.SGD:
                 return SGD(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     weight_decay=self.weight_decay,
                 )
             case Optimizers.Adam:
                 return Adam(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -137,7 +141,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.AdamW:
                 return AdamW(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -145,7 +149,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.AdamW8bit:
                 return AdamW8bit(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     eps=self.eps,
@@ -153,7 +157,7 @@ class OptimizerConfig(BaseModel):
                 )
             case Optimizers.Lion8bit:
                 return Lion8bit(
-                    params=model_parameters,
+                    params=params,
                     lr=self.learning_rate,
                     betas=self.betas,
                     weight_decay=self.weight_decay,  # type: ignore
@@ -163,7 +167,7 @@ class OptimizerConfig(BaseModel):
                     warn("Prodigy learning rate is not 1.0, this might cause instability.")
                 return Prodigy(
                     lr=self.learning_rate,
-                    params=model_parameters,
+                    params=params,
                     betas=self.betas,
                     weight_decay=self.weight_decay,  # type: ignore
                     safeguard_warmup=True,
@@ -174,7 +178,10 @@
 class ModelConfig(BaseModel):
     checkpoint: Path | None = None
     train: bool = True
-    learning_rate: float | None = None  # TODO: Implement this
+    learning_rate: float | None = None
+    betas: tuple[float, float] | None = None
+    eps: float | None = None
+    weight_decay: float | None = None
 
 
 class GyroDropoutConfig(BaseModel):
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 46ae37a..d06106e 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -329,6 +329,32 @@ class Trainer(Generic[ConfigType, Batch], ABC):
         """Returns a list of learnable parameters in all models"""
         return [param for model in self.models.values() for param in model.parameters() if param.requires_grad]
 
+    @property
+    def optimizer_parameters(self) -> list[dict[str, Any]]:
+        """
+        Returns a list of dicts containing the params and optimizer options for each model.
+        See https://pytorch.org/docs/stable/optim.html#per-parameter-options for more details.
+        """
+        params: list[dict[str, Any]] = []
+        for model_name, model in self.models.items():
+            model_params = [param for param in model.parameters() if param.requires_grad]
+            model_config = self.config.models[model_name]
+            model_optim_conf: dict[str, Any] = {}
+
+            if model_config.learning_rate is not None:
+                model_optim_conf["lr"] = model_config.learning_rate
+            if model_config.weight_decay is not None:
+                model_optim_conf["weight_decay"] = model_config.weight_decay
+            if model_config.betas is not None:
+                model_optim_conf["betas"] = model_config.betas
+            if model_config.eps is not None:
+                model_optim_conf["eps"] = model_config.eps
+
+            for param in model_params:
+                params.append({"params": param, **model_optim_conf})
+
+        return params
+
     @property
     def learnable_parameter_count(self) -> int:
         """Returns the number of learnable parameters in all models"""
@@ -353,7 +379,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     def optimizer(self) -> Optimizer:
         formatted_param_count = human_readable_number(number=self.learnable_parameter_count)
         logger.info(f"Total number of learnable parameters in the model(s): {formatted_param_count}")
-        optimizer = self.config.optimizer.get(model_parameters=self.learnable_parameters)
+
+        optimizer = self.config.optimizer.get(params=self.optimizer_parameters)
         return optimizer
 
     @cached_property
diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml
new file mode 100644
index 0000000..5a126fb
--- /dev/null
+++ b/tests/training_utils/mock_config_2_models.toml
@@ -0,0 +1,34 @@
+[models.mock_model1]
+train = true
+learning_rate = 1e-5
+
+[models.mock_model2]
+train = true
+
+[training]
+duration = "100:epoch"
+seed = 0
+batch_size = 4
+gradient_accumulation = "4:step"
+clip_grad_norm = 1.0
+evaluation_interval = "5:epoch"
+evaluation_seed = 1
+
+[optimizer]
+optimizer = "SGD"
+learning_rate = 1
+momentum = 0.9
+
+[scheduler]
+scheduler_type = "ConstantLR"
+update_interval = "1:step"
+
+[dropout]
+dropout = 0.0
+
+[checkpointing]
+save_interval = "10:epoch"
+
+[wandb]
+mode = "disabled"
+project = "mock_project"
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 6e8a493..62a6041 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -214,3 +214,39 @@ def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
     optimizer = warmup_scheduler.optimizer
     for group in optimizer.param_groups:
         assert group["lr"] == 0.1
+
+
+class MockTrainerWith2Models(MockTrainer):
+    @cached_property
+    def mock_model1(self) -> MockModel:
+        return MockModel()
+
+    @cached_property
+    def mock_model2(self) -> MockModel:
+        return MockModel()
+
+    def load_models(self) -> dict[str, fl.Module]:
+        return {"mock_model1": self.mock_model1, "mock_model2": self.mock_model2}
+
+    def compute_loss(self, batch: MockBatch) -> Tensor:
+        self.step_counter += 1
+        inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device)
+        outputs = self.mock_model2(self.mock_model1(inputs))
+        return norm(outputs - targets)
+
+
+@pytest.fixture
+def mock_config_2_models(test_device: torch.device) -> MockConfig:
+    return MockConfig.load_from_toml(Path(__file__).parent / "mock_config_2_models.toml")
+
+
+@pytest.fixture
+def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2Models:
+    return MockTrainerWith2Models(config=mock_config_2_models)
+
+
+def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
+    assert (
+        len(mock_trainer_2_models.optimizer.param_groups) == 12
+    )  # 12 == (3 [linear layers] * 2 [bias + weights]) * 2 [models]
+    assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
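For context, here is a minimal standalone sketch (not part of the patch above) of the PyTorch per-parameter-group behavior that `optimizer_parameters` relies on: groups that set `"lr"` keep their own value, while groups that omit it inherit the optimizer-level default. The two `nn.Linear` modules are hypothetical stand-ins for `mock_model1` and `mock_model2`.

```python
# Sketch only: mirrors what Trainer.optimizer_parameters builds, using plain PyTorch.
from torch import nn
from torch.optim import SGD

model1 = nn.Linear(2, 2)  # stand-in for mock_model1 (per-model lr = 1e-5)
model2 = nn.Linear(2, 2)  # stand-in for mock_model2 (no per-model options)

# One param group per parameter; only model1's groups carry an "lr" override.
param_groups = [{"params": p, "lr": 1e-5} for p in model1.parameters()]
param_groups += [{"params": p} for p in model2.parameters()]

optimizer = SGD(param_groups, lr=1, momentum=0.9)  # matches the [optimizer] table in the mock TOML

assert len(optimizer.param_groups) == 4  # (weight + bias) * 2 models
assert optimizer.param_groups[0]["lr"] == 1e-5  # per-model override kept
assert optimizer.param_groups[-1]["lr"] == 1  # falls back to the optimizer default
```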