update deprecated validator for field_validator

This commit is contained in:
limiteinductive 2024-02-13 17:24:53 +00:00 committed by Benjamin Trom
parent ab506b4db2
commit bec845553f
3 changed files with 18 additions and 6 deletions

View file

@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
import tomli import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, ConfigDict, validator from pydantic import BaseModel, ConfigDict, field_validator
from torch import Tensor from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer from torch.optim import SGD, Adam, AdamW, Optimizer
@ -32,7 +32,7 @@ class TrainingConfig(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
@validator("duration", "gradient_accumulation", "evaluation_interval", pre=True) @field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
def parse_field(cls, value: Any) -> TimeValue: def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value) return parse_number_unit_field(value)
@ -80,7 +80,7 @@ class SchedulerConfig(BaseModel):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
@validator("update_interval", "warmup", pre=True) @field_validator("update_interval", "warmup", mode="before")
def parse_field(cls, value: Any) -> TimeValue: def parse_field(cls, value: Any) -> TimeValue:
return parse_number_unit_field(value) return parse_number_unit_field(value)

View file

@ -1,14 +1,19 @@
import warnings
from abc import ABC from abc import ABC
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
import wandb
from PIL import Image from PIL import Image
from refiners.training_utils.callback import Callback, CallbackConfig from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.config import BaseConfig from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer, register_callback from refiners.training_utils.trainer import Trainer, register_callback
with warnings.catch_warnings():
# TODO: remove when https://github.com/wandb/wandb/issues/6711 gets solved
warnings.filterwarnings("ignore", category=DeprecationWarning, message="pkg_resources is deprecated as an API")
import wandb
number = float | int number = float | int
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]] WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]

View file

@ -1,3 +1,4 @@
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import cast from typing import cast
@ -222,8 +223,14 @@ def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:
def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None: def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
for _ in range(102): with warnings.catch_warnings():
warmup_scheduler.step() warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`",
)
for _ in range(102):
warmup_scheduler.step()
optimizer = warmup_scheduler.optimizer optimizer = warmup_scheduler.optimizer
for group in optimizer.param_groups: for group in optimizer.param_groups:
assert group["lr"] == 0.1 assert group["lr"] == 0.1