annotated validators for TimeValue

This commit is contained in:
limiteinductive 2024-04-25 13:32:04 +00:00 committed by Benjamin Trom
parent 22f4f4faf1
commit 0bec9a855d
2 changed files with 15 additions and 27 deletions

View file

@@ -1,17 +1,17 @@
from enum import Enum
from logging import warn
from pathlib import Path
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar
import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, BeforeValidator, ConfigDict
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -19,20 +19,21 @@ from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, Ti
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)]
IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)]
StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]
class TrainingConfig(BaseModel):
device: str = "cpu"
dtype: str = "float32"
duration: TimeValue = Iteration(1)
duration: TimeValueField = Iteration(1)
seed: int = 0
gradient_accumulation: Step = Step(1)
gradient_accumulation: StepField = Step(1)
gradient_clipping_max_norm: float | None = None
model_config = ConfigDict(extra="forbid")
@field_validator("duration", "gradient_accumulation", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
class Optimizers(str, Enum):
SGD = "SGD"
@@ -60,8 +61,8 @@ class LRSchedulerType(str, Enum):
class LRSchedulerConfig(BaseModel):
type: LRSchedulerType = LRSchedulerType.DEFAULT
update_interval: Iteration | Epoch = Iteration(1)
warmup: TimeValue = Iteration(0)
update_interval: IterationOrEpochField = Iteration(1)
warmup: TimeValueField = Iteration(0)
gamma: float = 0.1
lr_lambda: Callable[[int], float] | None = None
mode: Literal["min", "max"] = "min"
@@ -77,10 +78,6 @@ class LRSchedulerConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
@field_validator("update_interval", "warmup", mode="before")
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
class OptimizerConfig(BaseModel):
optimizer: Optimizers

View file

@@ -6,7 +6,6 @@ from typing import cast
import pytest
import torch
from pydantic import field_validator
from torch import Tensor, nn
from torch.optim import SGD
@@ -16,16 +15,12 @@ from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import (
Epoch,
Iteration,
Step,
TimeValue,
TimeValueInput,
count_learnable_parameters,
human_readable_number,
parse_number_unit_field,
scoped_seed,
)
from refiners.training_utils.config import BaseConfig, ModelConfig
from refiners.training_utils.config import BaseConfig, IterationOrEpochField, ModelConfig, TimeValueField
from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader
from refiners.training_utils.trainer import (
Trainer,
@@ -49,13 +44,9 @@ class MockModelConfig(ModelConfig):
class MockCallbackConfig(CallbackConfig):
on_batch_end_interval: Step | Iteration | Epoch
on_batch_end_interval: TimeValueField
on_batch_end_seed: int
on_optimizer_step_interval: Iteration | Epoch
@field_validator("on_batch_end_interval", "on_optimizer_step_interval", mode="before")
def parse_field(cls, value: TimeValueInput) -> TimeValue:
return parse_number_unit_field(value)
on_optimizer_step_interval: IterationOrEpochField
class MockConfig(BaseConfig):