annotated validators for TimeValue

This commit is contained in:
limiteinductive 2024-04-25 13:32:04 +00:00 committed by Benjamin Trom
parent 22f4f4faf1
commit 0bec9a855d
2 changed files with 15 additions and 27 deletions

View file

@@ -1,17 +1,17 @@
from enum import Enum
from logging import warn
from pathlib import Path
from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar

import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
from prodigyopt import Prodigy  # type: ignore
from pydantic import BaseModel, BeforeValidator, ConfigDict
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer

from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -19,20 +19,21 @@ from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, Ti
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]

# Annotated aliases for time-valued config fields: pydantic runs
# `parse_number_unit_field` before validation, so these fields accept whatever
# raw inputs that helper understands (presumably "number + unit" values —
# confirm against `refiners.training_utils.common`) in addition to
# already-constructed TimeValue instances.
TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)]
IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)]
StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]
class TrainingConfig(BaseModel):
    """Top-level training loop configuration.

    Time-valued fields use the `*Field` annotated aliases, which coerce raw
    inputs through `parse_number_unit_field` before pydantic validation; the
    previous per-class `@field_validator` is therefore no longer needed.
    """

    device: str = "cpu"  # torch device string
    dtype: str = "float32"  # torch dtype name
    duration: TimeValueField = Iteration(1)  # total training duration
    seed: int = 0
    gradient_accumulation: StepField = Step(1)
    gradient_clipping_max_norm: float | None = None  # None disables clipping

    # forbid unknown keys so config typos fail loudly
    model_config = ConfigDict(extra="forbid")
class Optimizers(str, Enum): class Optimizers(str, Enum):
SGD = "SGD" SGD = "SGD"
@@ -60,8 +61,8 @@ class LRSchedulerType(str, Enum):
class LRSchedulerConfig(BaseModel): class LRSchedulerConfig(BaseModel):
type: LRSchedulerType = LRSchedulerType.DEFAULT type: LRSchedulerType = LRSchedulerType.DEFAULT
update_interval: Iteration | Epoch = Iteration(1) update_interval: IterationOrEpochField = Iteration(1)
warmup: TimeValue = Iteration(0) warmup: TimeValueField = Iteration(0)
gamma: float = 0.1 gamma: float = 0.1
lr_lambda: Callable[[int], float] | None = None lr_lambda: Callable[[int], float] | None = None
mode: Literal["min", "max"] = "min" mode: Literal["min", "max"] = "min"
@@ -77,10 +78,6 @@ class LRSchedulerConfig(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
@field_validator("update_interval", "warmup", mode="before")
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
class OptimizerConfig(BaseModel): class OptimizerConfig(BaseModel):
optimizer: Optimizers optimizer: Optimizers

View file

@@ -6,7 +6,6 @@ from typing import cast
import pytest
import torch
from torch import Tensor, nn
from torch.optim import SGD
@@ -16,16 +15,12 @@ from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.clock import ClockConfig from refiners.training_utils.clock import ClockConfig
from refiners.training_utils.common import (
    Epoch,
    Step,
    count_learnable_parameters,
    human_readable_number,
    scoped_seed,
)
from refiners.training_utils.config import BaseConfig, IterationOrEpochField, ModelConfig, TimeValueField
from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader
from refiners.training_utils.trainer import ( from refiners.training_utils.trainer import (
Trainer, Trainer,
@@ -49,13 +44,9 @@ class MockModelConfig(ModelConfig):
class MockCallbackConfig(CallbackConfig):
    """Configuration for the mock callback used in the trainer tests.

    Interval fields use the annotated aliases exported by
    `refiners.training_utils.config`, which coerce raw inputs through
    `parse_number_unit_field` before validation — replacing the per-class
    `@field_validator` previously defined here.
    """

    on_batch_end_interval: TimeValueField
    on_batch_end_seed: int
    on_optimizer_step_interval: IterationOrEpochField
class MockConfig(BaseConfig): class MockConfig(BaseConfig):