diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index afaa641..4960713 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -1,23 +1,24 @@
 from enum import Enum
 from logging import warn
 from pathlib import Path
-from typing import Any, Callable, Literal, Type, TypeVar
+from typing import Any, Callable, Iterable, Literal, Type, TypeVar
 
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, validator
+from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
-
-try:
-    from torch.optim.optimizer import params_t as ParamsT  # PyTorch 2.1. TODO: remove "soon"
-except ImportError as e:
-    from torch.optim.optimizer import ParamsT
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
 
 import refiners.fluxion.layers as fl
 from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
 
+# PyTorch optimizer parameters type
+# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
+# See https://github.com/pytorch/pytorch/pull/111114
+ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
+
 __all__ = [
     "parse_number_unit_field",
     "TimeUnit",
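For context on what the patch's manually defined `ParamsT` alias buys: it matches the shape PyTorch optimizers accept for their `params` argument, either a plain iterable of tensors or an iterable of parameter-group dicts. Below is a minimal sketch (not part of the diff) showing the alias used in a signature; `build_optimizer` and its defaults are hypothetical names for illustration only.

```python
from typing import Any, Iterable

from torch import Tensor
from torch.optim import AdamW, Optimizer

# Same alias as in the patch: either an iterable of tensors
# or an iterable of parameter-group dicts.
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]


def build_optimizer(params: ParamsT, lr: float = 1e-4) -> Optimizer:
    """Hypothetical factory showing `ParamsT` in a signature."""
    return AdamW(params, lr=lr)
```

Both call styles type-check against the alias, e.g. `build_optimizer(model.parameters())` and `build_optimizer([{"params": head.parameters(), "lr": 1e-3}])`. Defining the alias inline sidesteps the old try/except, which broke because the symbol was renamed from `params_t` to `ParamsT` between PyTorch 2.1 and 2.2 (see the linked pytorch/pytorch#111114).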