From 5a7085bb3a622c1fa0fbe8398c519f4092ffcce9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Deltheil?=
Date: Fri, 9 Feb 2024 13:24:15 +0000
Subject: [PATCH] training_utils/config.py: inline type alias

Follow up of #227
---
 src/refiners/training_utils/config.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index afaa641..4960713 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -1,23 +1,24 @@
 from enum import Enum
 from logging import warn
 from pathlib import Path
-from typing import Any, Callable, Literal, Type, TypeVar
+from typing import Any, Callable, Iterable, Literal, Type, TypeVar
 
 import tomli
 from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, validator
+from torch import Tensor
 from torch.optim import SGD, Adam, AdamW, Optimizer
-
-try:
-    from torch.optim.optimizer import params_t as ParamsT  # PyTorch 2.1. TODO: remove "soon"
-except ImportError as e:
-    from torch.optim.optimizer import ParamsT
 from typing_extensions import TypedDict  # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
 
 import refiners.fluxion.layers as fl
 from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
 
+# PyTorch optimizer parameters type
+# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
+# See https://github.com/pytorch/pytorch/pull/111114
+ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
+
 __all__ = [
     "parse_number_unit_field",
     "TimeUnit",
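
For context, the inlined alias mirrors the `params` argument accepted by PyTorch optimizers: either a plain iterable of tensors or an iterable of parameter-group dicts. Below is a minimal usage sketch, not part of the patch; the `model`, `opt_simple`, and `opt_grouped` names and the choice of `AdamW`/`nn.Linear` are purely illustrative.

    from typing import Any, Iterable

    from torch import Tensor, nn
    from torch.optim import AdamW

    # Same alias as the one inlined by the patch above
    ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]

    model = nn.Linear(8, 2)

    # Form 1: a plain iterable of tensors
    simple_params: ParamsT = model.parameters()
    opt_simple = AdamW(simple_params, lr=1e-4)

    # Form 2: an iterable of parameter-group dicts with per-group hyperparameters
    grouped_params: ParamsT = [
        {"params": [model.weight], "lr": 1e-4},
        {"params": [model.bias], "lr": 1e-3, "weight_decay": 0.0},
    ]
    opt_grouped = AdamW(grouped_params, lr=1e-4)

The `X | Y` union syntax on typing generics requires Python 3.10+ at runtime, which is why the alias can be written this way instead of using `typing.Union`.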