diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 8547e1e..91e41e2 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Iterable, Literal, Type, TypeVar import tomli from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore from prodigyopt import Prodigy # type: ignore -from pydantic import BaseModel, ConfigDict, validator +from pydantic import BaseModel, ConfigDict, field_validator from torch import Tensor from torch.optim import SGD, Adam, AdamW, Optimizer @@ -32,7 +32,7 @@ class TrainingConfig(BaseModel): model_config = ConfigDict(extra="forbid") - @validator("duration", "gradient_accumulation", "evaluation_interval", pre=True) + @field_validator("duration", "gradient_accumulation", "evaluation_interval", mode="before") def parse_field(cls, value: Any) -> TimeValue: return parse_number_unit_field(value) @@ -80,7 +80,7 @@ class SchedulerConfig(BaseModel): model_config = ConfigDict(extra="forbid") - @validator("update_interval", "warmup", pre=True) + @field_validator("update_interval", "warmup", mode="before") def parse_field(cls, value: Any) -> TimeValue: return parse_number_unit_field(value) diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index d4683d6..ac167b5 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -1,14 +1,19 @@ +import warnings from abc import ABC from pathlib import Path from typing import Any, Literal -import wandb from PIL import Image from refiners.training_utils.callback import Callback, CallbackConfig from refiners.training_utils.config import BaseConfig from refiners.training_utils.trainer import Trainer, register_callback +with warnings.catch_warnings(): + # TODO: remove when https://github.com/wandb/wandb/issues/6711 gets solved + warnings.filterwarnings("ignore", category=DeprecationWarning, message="pkg_resources is deprecated as an API") + import wandb + number = float | int WandbLoggable = number | Image.Image | list[number] | dict[str, list[number]] diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 68d5a4b..617e506 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -1,3 +1,4 @@ +import warnings from dataclasses import dataclass from pathlib import Path from typing import cast @@ -222,8 +223,14 @@ def test_initial_lr(warmup_scheduler: WarmupScheduler) -> None: def test_warmup_lr(warmup_scheduler: WarmupScheduler) -> None: - for _ in range(102): - warmup_scheduler.step() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r"Detected call of `lr_scheduler.step\(\)` before `optimizer.step\(\)`", + ) + for _ in range(102): + warmup_scheduler.step() optimizer = warmup_scheduler.optimizer for group in optimizer.param_groups: assert group["lr"] == 0.1