mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00
annotated validators for TimeValue
This commit is contained in:
parent
22f4f4faf1
commit
0bec9a855d
|
@ -1,17 +1,17 @@
|
|||
from enum import Enum
|
||||
from logging import warn
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
||||
from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar
|
||||
|
||||
import tomli
|
||||
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||
from prodigyopt import Prodigy # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, BeforeValidator, ConfigDict
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
|
||||
|
||||
# PyTorch optimizer parameters type
|
||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||
|
@ -19,20 +19,21 @@ from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, Ti
|
|||
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||
|
||||
|
||||
TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)]
|
||||
IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)]
|
||||
StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
device: str = "cpu"
|
||||
dtype: str = "float32"
|
||||
duration: TimeValue = Iteration(1)
|
||||
duration: TimeValueField = Iteration(1)
|
||||
seed: int = 0
|
||||
gradient_accumulation: Step = Step(1)
|
||||
gradient_accumulation: StepField = Step(1)
|
||||
gradient_clipping_max_norm: float | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("duration", "gradient_accumulation", mode="before")
|
||||
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
||||
class Optimizers(str, Enum):
|
||||
SGD = "SGD"
|
||||
|
@ -60,8 +61,8 @@ class LRSchedulerType(str, Enum):
|
|||
|
||||
class LRSchedulerConfig(BaseModel):
|
||||
type: LRSchedulerType = LRSchedulerType.DEFAULT
|
||||
update_interval: Iteration | Epoch = Iteration(1)
|
||||
warmup: TimeValue = Iteration(0)
|
||||
update_interval: IterationOrEpochField = Iteration(1)
|
||||
warmup: TimeValueField = Iteration(0)
|
||||
gamma: float = 0.1
|
||||
lr_lambda: Callable[[int], float] | None = None
|
||||
mode: Literal["min", "max"] = "min"
|
||||
|
@ -77,10 +78,6 @@ class LRSchedulerConfig(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("update_interval", "warmup", mode="before")
|
||||
def parse_field(cls, value: Any) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
||||
class OptimizerConfig(BaseModel):
|
||||
optimizer: Optimizers
|
||||
|
|
|
@ -6,7 +6,6 @@ from typing import cast
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from pydantic import field_validator
|
||||
from torch import Tensor, nn
|
||||
from torch.optim import SGD
|
||||
|
||||
|
@ -16,16 +15,12 @@ from refiners.training_utils.callback import Callback, CallbackConfig
|
|||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import (
|
||||
Epoch,
|
||||
Iteration,
|
||||
Step,
|
||||
TimeValue,
|
||||
TimeValueInput,
|
||||
count_learnable_parameters,
|
||||
human_readable_number,
|
||||
parse_number_unit_field,
|
||||
scoped_seed,
|
||||
)
|
||||
from refiners.training_utils.config import BaseConfig, ModelConfig
|
||||
from refiners.training_utils.config import BaseConfig, IterationOrEpochField, ModelConfig, TimeValueField
|
||||
from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader
|
||||
from refiners.training_utils.trainer import (
|
||||
Trainer,
|
||||
|
@ -49,13 +44,9 @@ class MockModelConfig(ModelConfig):
|
|||
|
||||
|
||||
class MockCallbackConfig(CallbackConfig):
|
||||
on_batch_end_interval: Step | Iteration | Epoch
|
||||
on_batch_end_interval: TimeValueField
|
||||
on_batch_end_seed: int
|
||||
on_optimizer_step_interval: Iteration | Epoch
|
||||
|
||||
@field_validator("on_batch_end_interval", "on_optimizer_step_interval", mode="before")
|
||||
def parse_field(cls, value: TimeValueInput) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
on_optimizer_step_interval: IterationOrEpochField
|
||||
|
||||
|
||||
class MockConfig(BaseConfig):
|
||||
|
|
Loading…
Reference in a new issue