mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
parent
d590c0e2fa
commit
5a7085bb3a
|
@ -1,23 +1,24 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import warn
|
from logging import warn
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Literal, Type, TypeVar
|
from typing import Any, Callable, Iterable, Literal, Type, TypeVar
|
||||||
|
|
||||||
import tomli
|
import tomli
|
||||||
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
||||||
from prodigyopt import Prodigy # type: ignore
|
from prodigyopt import Prodigy # type: ignore
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
|
from torch import Tensor
|
||||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.optim.optimizer import params_t as ParamsT # PyTorch 2.1. TODO: remove "soon"
|
|
||||||
except ImportError as e:
|
|
||||||
from torch.optim.optimizer import ParamsT
|
|
||||||
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
from typing_extensions import TypedDict # https://errors.pydantic.dev/2.0b3/u/typed-dict-version
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
|
from refiners.training_utils.dropout import apply_dropout, apply_gyro_dropout
|
||||||
|
|
||||||
|
# PyTorch optimizer parameters type
|
||||||
|
# TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced
|
||||||
|
# See https://github.com/pytorch/pytorch/pull/111114
|
||||||
|
ParamsT = Iterable[Tensor] | Iterable[dict[str, Any]]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"parse_number_unit_field",
|
"parse_number_unit_field",
|
||||||
"TimeUnit",
|
"TimeUnit",
|
||||||
|
|
Loading…
Reference in a new issue