Mirror of https://github.com/finegrain-ai/refiners.git
Synced 2024-11-09 15:02:01 +00:00
Allow optional train ModelConfig + forbid extra input for configs
This commit is contained in:
parent 402d3105b4
commit f541badcb3
```diff
@@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
 from prodigyopt import Prodigy # type: ignore
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, validator
 from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
 from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
```
```diff
@@ -66,6 +66,8 @@ class TrainingConfig(BaseModel):
     evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     evaluation_seed: int = 0
 
+    model_config = ConfigDict(extra="forbid")
+
     @validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)
```
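The `pre=True` validators route raw TOML values such as `duration = "100:epoch"` through `parse_number_unit_field` before field validation. That helper is outside this diff; the following is a plausible sketch inferred from the `"number:unit"` strings in the mock configs and the `TimeValue` default above, not the repository's actual implementation.

```python
# Hypothetical stand-in for refiners' parse_number_unit_field; the names follow
# the diff, the behavior is inferred.
from enum import Enum
from typing import Any, TypedDict


class TimeUnit(str, Enum):
    STEP = "step"
    EPOCH = "epoch"
    ITERATION = "iteration"


class TimeValue(TypedDict):
    number: int
    unit: TimeUnit


def parse_number_unit_field(value: Any) -> TimeValue:
    match value:
        case str():  # e.g. "100:epoch" or "1:step" from the TOML files
            number, _, unit = value.partition(":")
            return {"number": int(number.strip()), "unit": TimeUnit(unit.strip().lower())}
        case {"number": int() as number, "unit": TimeUnit() as unit}:
            return {"number": number, "unit": unit}  # already in TimeValue form
        case _:
            raise ValueError(f"Unsupported number-unit value: {value!r}")
```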
```diff
@@ -112,6 +114,8 @@ class SchedulerConfig(BaseModel):
     max_lr: float | list[float] = 0
     eta_min: float = 0
 
+    model_config = ConfigDict(extra="forbid")
+
     @validator("update_interval", "warmup", pre=True)
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)
```
```diff
@@ -124,6 +128,8 @@ class OptimizerConfig(BaseModel):
     eps: float = 1e-8
     weight_decay: float = 0.0
 
+    model_config = ConfigDict(extra="forbid")
+
     def get(self, params: ParamsT) -> Optimizer:
         match self.optimizer:
             case Optimizers.SGD:
```
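`OptimizerConfig.get` dispatches on the configured optimizer with a `match` statement, of which only the `SGD` arm is visible here. A hedged sketch of that factory pattern follows; the `Optimizers` members beyond `SGD` and the exact constructor wiring are assumptions, not copied from the repository.

```python
# Illustrative match-based optimizer factory; field names follow
# OptimizerConfig above, everything else is assumed.
from enum import Enum

from torch.nn import Parameter
from torch.optim import SGD, Adam, AdamW, Optimizer


class Optimizers(str, Enum):
    SGD = "SGD"
    Adam = "Adam"
    AdamW = "AdamW"


def build_optimizer(
    optimizer: Optimizers,
    params: list[Parameter],
    learning_rate: float,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
) -> Optimizer:
    match optimizer:
        case Optimizers.SGD:
            return SGD(params, lr=learning_rate, weight_decay=weight_decay)
        case Optimizers.Adam:
            return Adam(params, lr=learning_rate, eps=eps, weight_decay=weight_decay)
        case Optimizers.AdamW:
            return AdamW(params, lr=learning_rate, eps=eps, weight_decay=weight_decay)
        case _:
            raise ValueError(f"Unsupported optimizer: {optimizer}")
```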
```diff
@@ -178,12 +184,16 @@ class OptimizerConfig(BaseModel):
 
 class ModelConfig(BaseModel):
     checkpoint: Path | None = None
     train: bool = True
+    # If None, then requires_grad will NOT be changed when loading the model
+    # this can be useful if you want to train only a part of the model
+    requires_grad: bool | None = None
     learning_rate: float | None = None
     betas: tuple[float, float] | None = None
     eps: float | None = None
     weight_decay: float | None = None
 
+    model_config = ConfigDict(extra="forbid")
 
 
 class GyroDropoutConfig(BaseModel):
     total_subnetworks: int = 512
```
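The tri-state `requires_grad` is the core of the commit's first half: `True` and `False` force the gradient flags at load time, while `None` leaves them untouched so a partially frozen model stays that way. A runnable miniature, trimmed to the relevant fields:

```python
from pathlib import Path

from pydantic import BaseModel, ConfigDict


class ModelConfig(BaseModel):  # trimmed stand-in for the class in the diff
    checkpoint: Path | None = None
    train: bool = True
    requires_grad: bool | None = None

    model_config = ConfigDict(extra="forbid")


train_all = ModelConfig(requires_grad=True)    # unfreeze everything at load
freeze_all = ModelConfig(requires_grad=False)  # freeze everything at load
hands_off = ModelConfig()                      # None: leave the flags as the model set them
```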
```diff
@@ -191,11 +201,15 @@ class GyroDropoutConfig(BaseModel):
     iters_per_epoch: int = 512
     num_features_threshold: float = 5e5
 
+    model_config = ConfigDict(extra="forbid")
+
 
 class DropoutConfig(BaseModel):
     dropout_probability: float = 0.0
     gyro_dropout: GyroDropoutConfig | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
     def apply_dropout(self, model: fl.Chain) -> None:
         if self.dropout_probability > 0.0:
             if self.gyro_dropout is not None:
```
```diff
@@ -214,6 +228,8 @@ class BaseConfig(BaseModel):
     scheduler: SchedulerConfig
     dropout: DropoutConfig
 
+    model_config = ConfigDict(extra="forbid")
+
     @classmethod
     def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
         with open(file=toml_path, mode="rb") as f:
```
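The hunk cuts off the body of `load_from_toml`, but given the `tomli` import and the binary-mode `open`, it plausibly just feeds the parsed mapping to the constructor. A sketch under that assumption, with an illustrative class name:

```python
# Assumed shape of load_from_toml, not the verbatim source.
from pathlib import Path
from typing import Type, TypeVar

import tomli
from pydantic import BaseModel

T = TypeVar("T", bound="ExampleConfig")


class ExampleConfig(BaseModel):
    @classmethod
    def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
        with open(file=toml_path, mode="rb") as f:  # tomli expects a binary file
            return cls(**tomli.load(f))
```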
```diff
@@ -1,7 +1,7 @@
 from typing import Any, Generic, Protocol, TypeVar, cast
 
 from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
-from pydantic import BaseModel # type: ignore
+from pydantic import BaseModel, ConfigDict # type: ignore
 
 __all__ = ["load_hf_dataset", "HuggingfaceDataset"]
```
```diff
@@ -34,3 +34,5 @@ class HuggingfaceDatasetConfig(BaseModel):
     use_verification: bool = False
     resize_image_min_size: int = 512
     resize_image_max_size: int = 576
+
+    model_config = ConfigDict(extra="forbid")
```
```diff
@@ -459,7 +459,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             model.load_from_safetensors(tensors_path=checkpoint)
         else:
             logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
-        model.requires_grad_(requires_grad=self.config.models[model_name].train)
+        if (requires_grad := self.config.models[model_name].requires_grad) is not None:
+            model.requires_grad_(requires_grad=requires_grad)
         model.to(self.device)
         model.zero_grad()
```
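This is the behavioral change the new field enables: the old line unconditionally overwrote every model's gradient flags from `train`, while the guarded version only touches them when `requires_grad` is explicitly set. A minimal sketch of the resulting semantics with a stand-in module:

```python
import torch


def apply_requires_grad(model: torch.nn.Module, requires_grad: bool | None) -> None:
    # Mirrors the guarded call above: None means "do not touch the flags".
    if requires_grad is not None:
        model.requires_grad_(requires_grad=requires_grad)


model = torch.nn.Linear(4, 4)
model.weight.requires_grad_(False)  # pretend part of the model is pre-frozen

apply_requires_grad(model, None)   # weight stays frozen, bias stays trainable
apply_requires_grad(model, True)   # every parameter becomes trainable again
```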
```diff
@@ -5,7 +5,7 @@ from typing import Any, Literal
 
 import wandb
 from PIL import Image
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.config import BaseConfig
```
```diff
@@ -87,6 +87,8 @@ class WandbConfig(BaseModel):
     anonymous: Literal["never", "allow", "must"] | None = None
     id: str | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
 
 AnyTrainer = Trainer[BaseConfig, Any]
```
```diff
@@ -1,5 +1,6 @@
 [models.mock_model]
 train = true
+requires_grad = true
 
 [training]
 duration = "100:epoch"
```
```diff
@@ -15,7 +16,6 @@ evaluation_seed = 1
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
-momentum = 0.9
 
 [scheduler]
 scheduler_type = "ConstantLR"
```
```diff
@@ -23,4 +23,4 @@ update_interval = "1:step"
 warmup = "20:step"
 
 [dropout]
-dropout = 0.0
+dropout_probability = 0.0
```
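The rename is not cosmetic: `dropout` is not among the `DropoutConfig` fields visible in the diff, so pydantic previously discarded the key and the test silently exercised the default. Under `extra="forbid"` such stale keys (like the unused `momentum` above) now abort validation, which is presumably how they surfaced. A self-contained demonstration with a trimmed stand-in class:

```python
import tomli
from pydantic import BaseModel, ConfigDict, ValidationError


class ExampleDropoutConfig(BaseModel):  # trimmed stand-in
    dropout_probability: float = 0.0

    model_config = ConfigDict(extra="forbid")


stale = tomli.loads("dropout = 0.0")  # the old key, silently ignored before
try:
    ExampleDropoutConfig(**stale)
except ValidationError as err:
    print(err)  # reports that extra inputs are not permitted for `dropout`

print(ExampleDropoutConfig(**tomli.loads("dropout_probability = 0.0")))
```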
```diff
@@ -1,9 +1,9 @@
 [models.mock_model1]
-train = true
+requires_grad = true
 learning_rate = 1e-5
 
 [models.mock_model2]
-train = true
+requires_grad = true
 
 [training]
 duration = "100:epoch"
```
```diff
@@ -17,18 +17,10 @@ evaluation_seed = 1
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
-momentum = 0.9
 
 [scheduler]
 scheduler_type = "ConstantLR"
 update_interval = "1:step"
 
 [dropout]
-dropout = 0.0
-
-[checkpointing]
-save_interval = "10:epoch"
-
-[wandb]
-mode = "disabled"
-project = "mock_project"
+dropout_probability = 0.0
```