Make ModelConfig's train flag an optional requires_grad + forbid extra inputs for configs

Authored by limiteinductive on 2024-02-10 14:53:18 +00:00, committed by Benjamin Trom
parent 402d3105b4
commit f541badcb3
6 changed files with 32 additions and 19 deletions
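
In short: every training config model now sets model_config = ConfigDict(extra="forbid"), so unknown keys in a TOML config fail validation instead of being silently ignored, and ModelConfig.train: bool is replaced by an optional requires_grad: bool | None (None means the model's requires_grad flags are left untouched). A minimal sketch of the forbid behavior, using a hypothetical ExampleConfig that is not part of the codebase:

from pydantic import BaseModel, ConfigDict, ValidationError

class ExampleConfig(BaseModel):
    learning_rate: float = 1e-4

    # the same setting this commit adds to every config class
    model_config = ConfigDict(extra="forbid")

ExampleConfig(learning_rate=1e-5)  # fine

try:
    # "momentum" is not a declared field, so validation now fails loudly
    ExampleConfig(learning_rate=1e-5, momentum=0.9)
except ValidationError as err:
    print(err)  # reports the extra input, so typos in config files surface early

This is also why the stray momentum, [checkpointing], and [wandb] entries had to be dropped from the mock TOML configs below.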

View file

@@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, validator
 from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
@@ -66,6 +66,8 @@ class TrainingConfig(BaseModel):
     evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     evaluation_seed: int = 0
 
+    model_config = ConfigDict(extra="forbid")
+
     @validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)
@@ -112,6 +114,8 @@ class SchedulerConfig(BaseModel):
     max_lr: float | list[float] = 0
     eta_min: float = 0
 
+    model_config = ConfigDict(extra="forbid")
+
     @validator("update_interval", "warmup", pre=True)
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)
@@ -124,6 +128,8 @@ class OptimizerConfig(BaseModel):
     eps: float = 1e-8
     weight_decay: float = 0.0
 
+    model_config = ConfigDict(extra="forbid")
+
     def get(self, params: ParamsT) -> Optimizer:
         match self.optimizer:
             case Optimizers.SGD:
@@ -178,12 +184,16 @@ class OptimizerConfig(BaseModel):
 
 class ModelConfig(BaseModel):
     checkpoint: Path | None = None
-    train: bool = True
+    # If None, then requires_grad will NOT be changed when loading the model;
+    # this can be useful if you want to train only a part of the model
+    requires_grad: bool | None = None
     learning_rate: float | None = None
     betas: tuple[float, float] | None = None
     eps: float | None = None
     weight_decay: float | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
 
 class GyroDropoutConfig(BaseModel):
     total_subnetworks: int = 512
@@ -191,11 +201,15 @@ class GyroDropoutConfig(BaseModel):
     iters_per_epoch: int = 512
     num_features_threshold: float = 5e5
 
+    model_config = ConfigDict(extra="forbid")
+
 
 class DropoutConfig(BaseModel):
     dropout_probability: float = 0.0
     gyro_dropout: GyroDropoutConfig | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
     def apply_dropout(self, model: fl.Chain) -> None:
         if self.dropout_probability > 0.0:
             if self.gyro_dropout is not None:
@@ -214,6 +228,8 @@ class BaseConfig(BaseModel):
     scheduler: SchedulerConfig
     dropout: DropoutConfig
 
+    model_config = ConfigDict(extra="forbid")
+
     @classmethod
     def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
         with open(file=toml_path, mode="rb") as f:

View file

@@ -1,7 +1,7 @@
 from typing import Any, Generic, Protocol, TypeVar, cast
 
 from datasets import VerificationMode, load_dataset as _load_dataset  # type: ignore
-from pydantic import BaseModel  # type: ignore
+from pydantic import BaseModel, ConfigDict  # type: ignore
 
 __all__ = ["load_hf_dataset", "HuggingfaceDataset"]
@@ -34,3 +34,5 @@ class HuggingfaceDatasetConfig(BaseModel):
     use_verification: bool = False
     resize_image_min_size: int = 512
     resize_image_max_size: int = 576
+
+    model_config = ConfigDict(extra="forbid")

View file

@@ -459,7 +459,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 model.load_from_safetensors(tensors_path=checkpoint)
             else:
                 logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
-            model.requires_grad_(requires_grad=self.config.models[model_name].train)
+            if (requires_grad := self.config.models[model_name].requires_grad) is not None:
+                model.requires_grad_(requires_grad=requires_grad)
             model.to(self.device)
             model.zero_grad()
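
The loading logic above now distinguishes three states instead of two. A minimal sketch of the resulting behavior, using a plain torch module as a stand-in for the loaded model (the names below are illustrative, not from the codebase):

import torch

model = torch.nn.Linear(4, 4)  # stand-in for the model being loaded
requires_grad: bool | None = None  # value read from ModelConfig.requires_grad

if requires_grad is not None:
    # True -> make the whole model trainable, False -> freeze it entirely
    model.requires_grad_(requires_grad=requires_grad)
# None -> leave every parameter's requires_grad flag as it is, which is useful
# when only part of the model should be trained and was configured beforehand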

View file

@@ -5,7 +5,7 @@ from typing import Any, Literal
 import wandb
 from PIL import Image
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.config import BaseConfig
@@ -87,6 +87,8 @@ class WandbConfig(BaseModel):
     anonymous: Literal["never", "allow", "must"] | None = None
     id: str | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
 
 AnyTrainer = Trainer[BaseConfig, Any]

View file

@@ -1,5 +1,6 @@
 [models.mock_model]
-train = true
+requires_grad = true
 
 [training]
 duration = "100:epoch"
@@ -15,7 +16,6 @@ evaluation_seed = 1
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
-momentum = 0.9
 
 [scheduler]
 scheduler_type = "ConstantLR"
@@ -23,4 +23,4 @@ update_interval = "1:step"
 warmup = "20:step"
 
 [dropout]
-dropout = 0.0
+dropout_probability = 0.0

View file

@@ -1,9 +1,9 @@
 [models.mock_model1]
-train = true
+requires_grad = true
 learning_rate = 1e-5
 
 [models.mock_model2]
-train = true
+requires_grad = true
 
 [training]
 duration = "100:epoch"
@@ -17,18 +17,10 @@ evaluation_seed = 1
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
-momentum = 0.9
 
 [scheduler]
 scheduler_type = "ConstantLR"
 update_interval = "1:step"
 
 [dropout]
-dropout = 0.0
-
-[checkpointing]
-save_interval = "10:epoch"
-
-[wandb]
-mode = "disabled"
-project = "mock_project"
+dropout_probability = 0.0