training_utils/config.py: inline type alias

Follow up of #227
This commit is contained in:
Cédric Deltheil 2024-02-09 13:24:15 +00:00 committed by Cédric Deltheil
parent d590c0e2fa
commit 5a7085bb3a

View file

@ -1,23 +1,24 @@
from enum import Enum from enum import Enum
from logging import warn from logging import warn
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Literal, Type, TypeVar from typing import Any, Callable, Iterable, Literal, Type, TypeVar
import tomli import tomli
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
from prodigyopt import Prodigy # type: ignore from prodigyopt import Prodigy # type: ignore
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, Optimizer from torch.optim import SGD, Adam, AdamW, Optimizer
try:
from torch.optim.optimizer import params_t as ParamsT # PyTorch 2.1. TODO: remove "soon"
except ImportError as e:
from torch.optim.optimizer import ParamsT
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
# PyTorch optimizer parameters type
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
# See https://github.com/pytorch/pytorch/pull/111114
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
__all__ = [ __all__ = [
"parse_number_unit_field", "parse_number_unit_field",
"TimeUnit", "TimeUnit",