update deprecated validator for field_validator

This commit is contained in:
limiteinductive 2024-02-13 17:24:53 +00:00 committed by Benjamin Trom
parent ab506b4db2
commit bec845553f
3 changed files with 18 additions and 6 deletions

View file

@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, ConfigDict, validator
from pydantic import BaseModel, ConfigDict, field_validator
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer
@ -32,7 +32,7 @@ class TrainingConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
@validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
@field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)
@ -80,7 +80,7 @@ class SchedulerConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
@validator("update_interval", "warmup", pre=True)
@field_validator("update_interval", "warmup", mode="before")
def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value)

View file

@ -1,14 +1,19 @@
import warnings
from abc import ABC
from pathlib import Path
from typing import Any, Literal
import wandb
from PIL import Image
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer, register_callback
with warnings.catch_warnings():
# TODO: remove when https://github.com/wandb/wandb/issues/6711 gets solved
warnings.filterwarnings("ignore", category=DeprecationWarning, message="pkg_resources is deprecated as an API")
import wandb
number = float | int
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]

View file

@ -1,3 +1,4 @@
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import cast
@ -222,8 +223,14 @@ def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:
def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
for _ in range(102):
warmup_scheduler.step()
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`",
)
for _ in range(102):
warmup_scheduler.step()
optimizer = warmup_scheduler.optimizer
for group in optimizer.param_groups:
assert group["lr"] == 0.1