Allow optional train ModelConfig + forbid extra input for configs

This commit is contained in:
limiteinductive 2024-02-10 14:53:18 +00:00 committed by Benjamin Trom
parent 402d3105b4
commit f541badcb3
6 changed files with 32 additions and 19 deletions

View file

@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, validator
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
@ -66,6 +66,8 @@ class TrainingConfig(BaseModel):
evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
evaluation_seed: int = 0
model_config = ConfigDict(extra="forbid")
@validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
@ -112,6 +114,8 @@ class SchedulerConfig(BaseModel):
max_lr: float | list[float] = 0
eta_min: float = 0
model_config = ConfigDict(extra="forbid")
@validator("update_interval", "warmup", pre=True)
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
@ -124,6 +128,8 @@ class OptimizerConfig(BaseModel):
eps: float = 1e-8
weight_decay: float = 0.0
model_config = ConfigDict(extra="forbid")
def get(self, params: ParamsT) -> Optimizer:
match self.optimizer:
case Optimizers.SGD:
@ -178,12 +184,16 @@ class OptimizerConfig(BaseModel):
class ModelConfig(BaseModel):
checkpoint: Path | None = None
train: bool = True
# If None, then requires_grad will NOT be changed when loading the model
# this can be useful if you want to train only a part of the model
requires_grad: bool | None = None
learning_rate: float | None = None
betas: tuple[float, float] | None = None
eps: float | None = None
weight_decay: float | None = None
model_config = ConfigDict(extra="forbid")
class GyroDropoutConfig(BaseModel):
total_subnetworks: int = 512
@ -191,11 +201,15 @@ class GyroDropoutConfig(BaseModel):
iters_per_epoch: int = 512
num_features_threshold: float = 5e5
model_config = ConfigDict(extra="forbid")
class DropoutConfig(BaseModel):
dropout_probability: float = 0.0
gyro_dropout: GyroDropoutConfig | None = None
model_config = ConfigDict(extra="forbid")
def apply_dropout(self, model: fl.Chain) -> None:
if self.dropout_probability > 0.0:
if self.gyro_dropout is not None:
@ -214,6 +228,8 @@ class BaseConfig(BaseModel):
scheduler: SchedulerConfig
dropout: DropoutConfig
model_config = ConfigDict(extra="forbid")
@classmethod
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
with open(file=toml_path, mode="rb") as f:

View file

@ -1,7 +1,7 @@
from typing import Any, Generic, Protocol, TypeVar, cast
from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
from pydantic import BaseModel # type: ignore
from pydantic import BaseModel, ConfigDict # type: ignore
__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
@ -34,3 +34,5 @@ class HuggingfaceDatasetConfig(BaseModel):
use_verification: bool = False
resize_image_min_size: int = 512
resize_image_max_size: int = 576
model_config = ConfigDict(extra="forbid")

View file

@ -459,7 +459,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
model.load_from_safetensors(tensors_path=checkpoint)
else:
logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
model.requires_grad_(requires_grad=self.config.models[model_name].train)
if (requires_grad := self.config.models[model_name].requires_grad) is not None:
model.requires_grad_(requires_grad=requires_grad)
model.to(self.device)
model.zero_grad()

View file

@ -5,7 +5,7 @@ from typing import Any, Literal
import wandb
from PIL import Image
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
@ -87,6 +87,8 @@ class WandbConfig(BaseModel):
anonymous: Literal["never", "allow", "must"] | None = None
id: str | None = None
model_config = ConfigDict(extra="forbid")
AnyTrainer = Trainer[BaseConfig, Any]

View file

@ -1,5 +1,6 @@
[models.mock_model]
train = true
requires_grad = true
[training]
duration = "100:epoch"
@ -15,7 +16,6 @@ evaluation_seed = 1
[optimizer]
optimizer = "SGD"
learning_rate = 1
momentum = 0.9
[scheduler]
scheduler_type = "ConstantLR"
@ -23,4 +23,4 @@ update_interval = "1:step"
warmup = "20:step"
[dropout]
dropout = 0.0
dropout_probability = 0.0

View file

@ -1,9 +1,9 @@
[models.mock_model1]
train = true
requires_grad = true
learning_rate = 1e-5
[models.mock_model2]
train = true
requires_grad = true
[training]
duration = "100:epoch"
@ -17,18 +17,10 @@ evaluation_seed = 1
[optimizer]
optimizer = "SGD"
learning_rate = 1
momentum = 0.9
[scheduler]
scheduler_type = "ConstantLR"
update_interval = "1:step"
[dropout]
dropout = 0.0
[checkpointing]
save_interval = "10:epoch"
[wandb]
mode = "disabled"
project = "mock_project"
dropout_probability = 0.0