mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
update deprecated validator for field_validator
This commit is contained in:
parent
ab506b4db2
commit
bec845553f
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
|||
import tomli
|
||||
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||
from prodigyopt import Prodigy # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
|
||||
|
@ -32,7 +32,7 @@ class TrainingConfig(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
|
||||
@field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
|
||||
def parse_field(cls, value: Any) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
@ -80,7 +80,7 @@ class SchedulerConfig(BaseModel):
|
|||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@validator("update_interval", "warmup", pre=True)
|
||||
@field_validator("update_interval", "warmup", mode="before")
|
||||
def parse_field(cls, value: Any) -> TimeValue:
|
||||
return parse_number_unit_field(value)
|
||||
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
import warnings
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import wandb
|
||||
from PIL import Image
|
||||
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer, register_callback
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# TODO: remove when https://github.com/wandb/wandb/issues/6711 gets solved
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, message="pkg_resources is deprecated as an API")
|
||||
import wandb
|
||||
|
||||
number = float | int
|
||||
WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
@ -222,8 +223,14 @@ def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:
|
|||
|
||||
|
||||
def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
|
||||
for _ in range(102):
|
||||
warmup_scheduler.step()
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=UserWarning,
|
||||
message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`",
|
||||
)
|
||||
for _ in range(102):
|
||||
warmup_scheduler.step()
|
||||
optimizer = warmup_scheduler.optimizer
|
||||
for group in optimizer.param_groups:
|
||||
assert group["lr"] == 0.1
|
||||
|
|
Loading…
Reference in a new issue