mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
update deprecated validator for field_validator
This commit is contained in:
parent ab506b4db2
commit bec845553f
@@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
-from pydantic import BaseModel, ConfigDict, validator
+from pydantic import BaseModel, ConfigDict, field_validator
 from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer

@@ -32,7 +32,7 @@ class TrainingConfig(BaseModel):

     model_config = ConfigDict(extra="forbid")

-    @validator("duration", "gradient_accumulation", "evaluation_interval", pre=True)
+    @field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before")
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)

@@ -80,7 +80,7 @@ class SchedulerConfig(BaseModel):

     model_config = ConfigDict(extra="forbid")

-    @validator("update_interval", "warmup", pre=True)
+    @field_validator("update_interval", "warmup", mode="before")
     def parse_field(cls, value: Any) -> TimeValue:
         return parse_number_unit_field(value)

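The three hunks above are the substance of the commit: Pydantic v2 renames the `validator` decorator to `field_validator` and replaces its `pre=True` keyword with `mode="before"`. Below is a minimal, runnable sketch of the migrated pattern; the `Config` model and its string-parsing logic are hypothetical stand-ins for refiners' `TrainingConfig` and `parse_number_unit_field`.

from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator


class Config(BaseModel):
    # Hypothetical stand-in for TrainingConfig; the real model also
    # validates "gradient_accumulation" and "evaluation_interval".
    duration: tuple[int, str]

    model_config = ConfigDict(extra="forbid")

    # Pydantic v2 spelling: field_validator(..., mode="before") replaces
    # the v1 validator(..., pre=True) and runs before type coercion.
    @field_validator("duration", mode="before")
    def parse_field(cls, value: Any) -> tuple[int, str]:
        # Simplified stand-in for parse_number_unit_field: split a
        # "number: unit" string such as "100: epoch" into its parts.
        if isinstance(value, str):
            number, _, unit = value.partition(":")
            return (int(number.strip()), unit.strip())
        return value


assert Config(duration="100: epoch").duration == (100, "epoch")

Because the validator runs in mode="before", it receives the raw input value (here a plain string) before Pydantic applies the declared field type, which is what lets a compact string form stand in for a structured value.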
@@ -1,14 +1,19 @@
+import warnings
 from abc import ABC
 from pathlib import Path
 from typing import Any, Literal

-import wandb
 from PIL import Image

 from refiners.training_utils.callback import Callback, CallbackConfig
 from refiners.training_utils.config import BaseConfig
 from refiners.training_utils.trainer import Trainer, register_callback

+with warnings.catch_warnings():
+    # TODO: remove when https://github.com/wandb/wandb/issues/6711 gets solved
+    warnings.filterwarnings("ignore", category=DeprecationWarning, message="pkg_resources is deprecated as an API")
+    import wandb
+
 number = float | int
 WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]]

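Moving `import wandb` inside a `warnings.catch_warnings()` block silences the `pkg_resources` DeprecationWarning that wandb currently raises at import time (tracked in the linked wandb issue #6711) without touching the process-wide warning configuration. A minimal sketch of the same scoping behavior, with a hypothetical `noisy_import()` standing in for wandb's import:

import warnings


def noisy_import() -> None:
    # Stand-in for a library that emits a DeprecationWarning on import.
    warnings.warn("pkg_resources is deprecated as an API", DeprecationWarning)


with warnings.catch_warnings():
    # The ignore filter lives only inside this block; the previous
    # warning state is restored as soon as the block exits.
    warnings.filterwarnings("ignore", category=DeprecationWarning, message="pkg_resources is deprecated as an API")
    noisy_import()  # silenced

noisy_import()  # warns again: the scoped filter is gone

Scoping the filter this way avoids a module-level `warnings.filterwarnings(...)` call, which would leak the ignore rule into every program that imports the module.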
@@ -1,3 +1,4 @@
+import warnings
 from dataclasses import dataclass
 from pathlib import Path
 from typing import cast
@@ -222,8 +223,14 @@ def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None:


 def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None:
-    for _ in range(102):
-        warmup_scheduler.step()
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            category=UserWarning,
+            message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`",
+        )
+        for _ in range(102):
+            warmup_scheduler.step()
     optimizer = warmup_scheduler.optimizer
     for group in optimizer.param_groups:
         assert group["lr"] == 0.1
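`test_warmup_lr` drives `warmup_scheduler.step()` 102 times without a matching `optimizer.step()`, so PyTorch emits its "Detected call of `lr_scheduler.step()` before `optimizer.step()`" UserWarning on the first step; the fix pins an ignore filter to exactly that message (the parentheses are escaped because the `message` argument is a regular expression). A minimal sketch that reproduces and silences the same warning, using a hypothetical one-parameter optimizer and a constant LambdaLR in place of refiners' WarmupScheduler:

import warnings

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)  # keeps lr constant

with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`",
    )
    scheduler.step()  # would warn here: optimizer.step() was never called

assert optimizer.param_groups[0]["lr"] == 0.1  # lr unchanged, as in the test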