From 0bec9a855de7eaaf5a727806a416ee3b50dec781 Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Thu, 25 Apr 2024 13:32:04 +0000
Subject: [PATCH] annotated validators for TimeValue

---
 src/refiners/training_utils/config.py | 27 ++++++++++++---------------
 tests/training_utils/test_trainer.py  | 15 +++------------
 2 files changed, 15 insertions(+), 27 deletions(-)

diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index a7d3ab6..4e68461 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -1,17 +1,17 @@
 from enum import Enum
 from logging import warn
 from pathlib import Path
-from typing import Any, Callable, Iterable, Literal, Type, TypeVar
+from typing import Annotated, Any, Callable, Iterable, Literal, Type, TypeVar
 
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, BeforeValidator, ConfigDict
 from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
 
 from refiners.training_utils.clock import ClockConfig
-from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, TimeValueInput, parse_number_unit_field
+from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
 
 # PyTorch optimizer parameters type
 # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
@@ -19,20 +19,21 @@ from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, Ti
 ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
 
 
+TimeValueField = Annotated[TimeValue, BeforeValidator(parse_number_unit_field)]
+IterationOrEpochField = Annotated[Iteration | Epoch, BeforeValidator(parse_number_unit_field)]
+StepField = Annotated[Step, BeforeValidator(parse_number_unit_field)]
+
+
 class TrainingConfig(BaseModel):
     device: str = "cpu"
     dtype: str = "float32"
-    duration: TimeValue = Iteration(1)
+    duration: TimeValueField = Iteration(1)
     seed: int = 0
-    gradient_accumulation: Step = Step(1)
+    gradient_accumulation: StepField = Step(1)
     gradient_clipping_max_norm: float | None = None
 
     model_config = ConfigDict(extra="forbid")
 
-    @field_validator("duration", "gradient_accumulation", mode="before")
-    def parse_field(cls, value: TimeValueInput) -> TimeValue:
-        return parse_number_unit_field(value)
-
 
 class Optimizers(str, Enum):
     SGD = "SGD"
@@ -60,8 +61,8 @@ class LRSchedulerType(str, Enum):
 
 class LRSchedulerConfig(BaseModel):
     type: LRSchedulerType = LRSchedulerType.DEFAULT
-    update_interval: Iteration | Epoch = Iteration(1)
-    warmup: TimeValue = Iteration(0)
+    update_interval: IterationOrEpochField = Iteration(1)
+    warmup: TimeValueField = Iteration(0)
     gamma: float = 0.1
     lr_lambda: Callable[[int], float] | None = None
     mode: Literal["min", "max"] = "min"
@@ -77,10 +78,6 @@ class LRSchedulerConfig(BaseModel):
 
     model_config = ConfigDict(extra="forbid")
 
-    @field_validator("update_interval", "warmup", mode="before")
-    def parse_field(cls, value: Any) -> TimeValue:
-        return parse_number_unit_field(value)
-
 
 class OptimizerConfig(BaseModel):
     optimizer: Optimizers
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 34dc8e9..a4ff334 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -6,7 +6,6 @@ from typing import cast
 
 import pytest
 import torch
-from pydantic import field_validator
 from torch import Tensor, nn
 from torch.optim import SGD
 
@@ -16,16 +15,12 @@ from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import (
     Epoch,
-    Iteration,
     Step,
-    TimeValue,
-    TimeValueInput,
     count_learnable_parameters,
     human_readable_number,
-    parse_number_unit_field,
     scoped_seed,
 )
-from refiners.training_utils.config import BaseConfig, ModelConfig
+from refiners.training_utils.config import BaseConfig, IterationOrEpochField, ModelConfig, TimeValueField
 from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
@@ -49,13 +44,9 @@ class MockModelConfig(ModelConfig):
 
 
 class MockCallbackConfig(CallbackConfig):
-    on_batch_end_interval: Step | Iteration | Epoch
+    on_batch_end_interval: TimeValueField
     on_batch_end_seed: int
-    on_optimizer_step_interval: Iteration | Epoch
-
-    @field_validator("on_batch_end_interval", "on_optimizer_step_interval", mode="before")
-    def parse_field(cls, value: TimeValueInput) -> TimeValue:
-        return parse_number_unit_field(value)
+    on_optimizer_step_interval: IterationOrEpochField
 
 
 class MockConfig(BaseConfig):