mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
parent
d590c0e2fa
commit
5a7085bb3a
|
@ -1,23 +1,24 @@
|
|||
from enum import Enum
|
||||
from logging import warn
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Type, TypeVar
|
||||
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
||||
|
||||
import tomli
|
||||
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||
from prodigyopt import Prodigy # type: ignore
|
||||
from pydantic import BaseModel, validator
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
|
||||
try:
|
||||
from torch.optim.optimizer import params_t as ParamsT # PyTorch 2.1. TODO: remove "soon"
|
||||
except ImportError as e:
|
||||
from torch.optim.optimizer import ParamsT
|
||||
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
|
||||
|
||||
# PyTorch optimizer parameters type
|
||||
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||
# See https://github.com/pytorch/pytorch/pull/111114
|
||||
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||
|
||||
__all__ = [
|
||||
"parse_number_unit_field",
|
||||
"TimeUnit",
|
||||
|
|
Loading…
Reference in a new issue