Mirror of https://github.com/finegrain-ai/refiners.git
Allow optional train ModelConfig + forbid extra input for configs

commit f541badcb3 (parent 402d3105b4)
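This commit makes two coordinated changes across the training-utils configuration models: the per-model "train: bool" flag becomes an optional "requires_grad: bool | None", and every config class now declares model_config = ConfigDict(extra="forbid") so that unknown input keys are rejected instead of silently ignored. A minimal sketch of that pydantic v2 pattern; ExampleConfig is a hypothetical stand-in, not one of the refiners classes:

from pydantic import BaseModel, ConfigDict, ValidationError

class ExampleConfig(BaseModel):
    # hypothetical stand-in for the config classes touched in this commit
    learning_rate: float = 1e-4

    # reject any input key that is not a declared field
    model_config = ConfigDict(extra="forbid")

try:
    ExampleConfig(learning_rate=1e-4, momentum=0.9)  # "momentum" is not a declared field
except ValidationError as err:
    print(err)  # pydantic reports the unexpected "momentum" input

The same ConfigDict line is added to each BaseModel subclass in the diff below.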
@@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, validator
 from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
@@ -66,6 +66,8 @@ class TrainingConfig(BaseModel):
     evaluation_interval: TimeValue = {"number": 1, "unit": TimeUnit.ITERATION}
     evaluation_seed: int = 0
 
+    model_config = ConfigDict(extra="forbid")
+
     @validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)
@@ -112,6 +114,8 @@ class SchedulerConfig(BaseModel):
     max_lr: float | list[float] = 0
     eta_min: float = 0
 
+    model_config = ConfigDict(extra="forbid")
+
     @validator("update_interval", "warmup", pre=True)
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)
@@ -124,6 +128,8 @@ class OptimizerConfig(BaseModel):
     eps: float = 1e-8
     weight_decay: float = 0.0
 
+    model_config = ConfigDict(extra="forbid")
+
     def get(self, params: ParamsT) -> Optimizer:
         match self.optimizer:
             case Optimizers.SGD:
@@ -178,12 +184,16 @@ class OptimizerConfig(BaseModel):
 
 class ModelConfig(BaseModel):
     checkpoint: Path | None = None
-    train: bool = True
+    # If None, then requires_grad will NOT be changed when loading the model
+    # this can be useful if you want to train only a part of the model
+    requires_grad: bool | None = None
     learning_rate: float | None = None
     betas: tuple[float, float] | None = None
     eps: float | None = None
     weight_decay: float | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
 
 class GyroDropoutConfig(BaseModel):
     total_subnetworks: int = 512
@@ -191,11 +201,15 @@ class GyroDropoutConfig(BaseModel):
     iters_per_epoch: int = 512
     num_features_threshold: float = 5e5
 
+    model_config = ConfigDict(extra="forbid")
+
 
 class DropoutConfig(BaseModel):
     dropout_probability: float = 0.0
     gyro_dropout: GyroDropoutConfig | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
     def apply_dropout(self, model: fl.Chain) -> None:
         if self.dropout_probability > 0.0:
             if self.gyro_dropout is not None:
@@ -214,6 +228,8 @@ class BaseConfig(BaseModel):
     scheduler: SchedulerConfig
     dropout: DropoutConfig
 
+    model_config = ConfigDict(extra="forbid")
+
     @classmethod
     def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
         with open(file=toml_path, mode="rb") as f:
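Because BaseConfig and its nested config classes now forbid extra input, a key that a TOML file misspells, or that no longer exists (such as the momentum and dropout keys removed from the mock configs further down), fails at load time instead of being dropped silently. A rough, self-contained sketch of that effect, assuming tomli and pydantic v2 are installed; DropoutSketch is a stand-in, not refiners' DropoutConfig:

import tomli  # the same TOML parser imported at the top of the config module
from pydantic import BaseModel, ConfigDict, ValidationError

class DropoutSketch(BaseModel):
    # stand-in mirroring DropoutConfig's renamed field
    dropout_probability: float = 0.0
    model_config = ConfigDict(extra="forbid")

toml_text = 'dropout = 0.0\n'  # the old key name used before this commit
try:
    DropoutSketch(**tomli.loads(toml_text))
except ValidationError as err:
    print(err)  # "dropout" is rejected, hence the rename to dropout_probability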
@@ -1,7 +1,7 @@
 from typing import Any, Generic, Protocol, TypeVar, cast
 
 from datasets import VerificationMode, load_dataset as _load_dataset  # type: ignore
-from pydantic import BaseModel  # type: ignore
+from pydantic import BaseModel, ConfigDict  # type: ignore
 
 __all__ = ["load_hf_dataset", "HuggingfaceDataset"]
 
@@ -34,3 +34,5 @@ class HuggingfaceDatasetConfig(BaseModel):
     use_verification: bool = False
     resize_image_min_size: int = 512
     resize_image_max_size: int = 576
+
+    model_config = ConfigDict(extra="forbid")
@@ -459,7 +459,8 @@ class Trainer(Generic[ConfigType, Batch], ABC):
             model.load_from_safetensors(tensors_path=checkpoint)
         else:
             logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.")
-        model.requires_grad_(requires_grad=self.config.models[model_name].train)
+        if (requires_grad := self.config.models[model_name].requires_grad) is not None:
+            model.requires_grad_(requires_grad=requires_grad)
         model.to(self.device)
         model.zero_grad()
 
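With requires_grad now optional, the Trainer only overrides gradient flags when a model's config sets the field explicitly; leaving it at None (the new default) preserves whatever requires_grad state the model already has, which is useful when only part of a model should train. A small illustrative sketch with a plain torch module, not the actual Trainer code:

import torch

model = torch.nn.Linear(4, 4)
model.weight.requires_grad_(False)  # pretend part of the model is already frozen

requires_grad: bool | None = None  # what ModelConfig.requires_grad would provide
if requires_grad is not None:
    model.requires_grad_(requires_grad)

print(model.weight.requires_grad)  # still False: with None, nothing is overridden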
@@ -5,7 +5,7 @@ from typing import Any, Literal
 
 import wandb
 from PIL import Image
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from refiners.training_utils.callback import Callback
 from refiners.training_utils.config import BaseConfig
@@ -87,6 +87,8 @@ class WandbConfig(BaseModel):
     anonymous: Literal["never", "allow", "must"] | None = None
     id: str | None = None
 
+    model_config = ConfigDict(extra="forbid")
+
 
 AnyTrainer = Trainer[BaseConfig, Any]
 
@@ -1,5 +1,6 @@
 [models.mock_model]
-train = true
+requires_grad = true
+
 
 [training]
 duration = "100:epoch"
@@ -15,7 +16,6 @@ evaluation_seed = 1
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
-momentum = 0.9
 
 [scheduler]
 scheduler_type = "ConstantLR"
@@ -23,4 +23,4 @@ update_interval = "1:step"
 warmup = "20:step"
 
 [dropout]
-dropout = 0.0
+dropout_probability = 0.0
@@ -1,9 +1,9 @@
 [models.mock_model1]
-train = true
+requires_grad = true
 learning_rate = 1e-5
 
 [models.mock_model2]
-train = true
+requires_grad = true
 
 [training]
 duration = "100:epoch"
@@ -17,18 +17,10 @@ evaluation_seed = 1
 [optimizer]
 optimizer = "SGD"
 learning_rate = 1
-momentum = 0.9
 
 [scheduler]
 scheduler_type = "ConstantLR"
 update_interval = "1:step"
 
 [dropout]
-dropout = 0.0
-
-[checkpointing]
-save_interval = "10:epoch"
-
-[wandb]
-mode = "disabled"
-project = "mock_project"
+dropout_probability = 0.0