mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
add support for pytorch 2.2 (2.1 is still supported)
also bump all dev dependencies to their latest version
This commit is contained in:
parent
45357c5548
commit
7eb8eb4c68
|
@ -54,11 +54,11 @@ build-backend = "hatchling.build"
|
||||||
[tool.rye]
|
[tool.rye]
|
||||||
managed = true
|
managed = true
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
"pyright == 1.1.342",
|
"pyright==1.1.349",
|
||||||
"ruff>=0.0.292",
|
"ruff>=0.1.15",
|
||||||
"docformatter>=1.7.5",
|
"docformatter>=1.7.5",
|
||||||
"pytest>=7.4.2",
|
"pytest>=8.0.0",
|
||||||
"mkdocs-material>=9.5.3",
|
"mkdocs-material>=9.5.6",
|
||||||
"coverage>=7.4.1",
|
"coverage>=7.4.1",
|
||||||
"mkdocstrings[python]>=0.24.0",
|
"mkdocstrings[python]>=0.24.0",
|
||||||
]
|
]
|
||||||
|
|
|
@ -8,38 +8,38 @@
|
||||||
# with-sources: false
|
# with-sources: false
|
||||||
|
|
||||||
-e file:.
|
-e file:.
|
||||||
aiohttp==3.9.1
|
aiohttp==3.9.3
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
annotated-types==0.6.0
|
annotated-types==0.6.0
|
||||||
appdirs==1.4.4
|
appdirs==1.4.4
|
||||||
async-timeout==4.0.3
|
async-timeout==4.0.3
|
||||||
attrs==23.1.0
|
attrs==23.2.0
|
||||||
bitsandbytes==0.41.3
|
bitsandbytes==0.42.0
|
||||||
certifi==2023.11.17
|
certifi==2023.11.17
|
||||||
charset-normalizer==3.3.2
|
charset-normalizer==3.3.2
|
||||||
click==8.1.7
|
click==8.1.7
|
||||||
datasets==2.15.0
|
datasets==2.16.1
|
||||||
diffusers==0.24.0
|
diffusers==0.25.1
|
||||||
dill==0.3.7
|
dill==0.3.7
|
||||||
docker-pycreds==0.4.0
|
docker-pycreds==0.4.0
|
||||||
filelock==3.13.1
|
filelock==3.13.1
|
||||||
frozenlist==1.4.0
|
frozenlist==1.4.1
|
||||||
fsspec==2023.10.0
|
fsspec==2023.10.0
|
||||||
gitdb==4.0.11
|
gitdb==4.0.11
|
||||||
gitpython==3.1.40
|
gitpython==3.1.41
|
||||||
huggingface-hub==0.19.4
|
huggingface-hub==0.20.3
|
||||||
idna==3.6
|
idna==3.6
|
||||||
importlib-metadata==7.0.0
|
importlib-metadata==7.0.1
|
||||||
invisible-watermark==0.2.0
|
invisible-watermark==0.2.0
|
||||||
jaxtyping==0.2.24
|
jaxtyping==0.2.25
|
||||||
jinja2==3.1.2
|
jinja2==3.1.3
|
||||||
loguru==0.7.2
|
loguru==0.7.2
|
||||||
markupsafe==2.1.3
|
markupsafe==2.1.4
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
multiprocess==0.70.15
|
multiprocess==0.70.15
|
||||||
networkx==3.2.1
|
networkx==3.2.1
|
||||||
numpy==1.26.2
|
numpy==1.26.3
|
||||||
nvidia-cublas-cu12==12.1.3.1
|
nvidia-cublas-cu12==12.1.3.1
|
||||||
nvidia-cuda-cupti-cu12==12.1.105
|
nvidia-cuda-cupti-cu12==12.1.105
|
||||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||||
|
@ -49,49 +49,49 @@ nvidia-cufft-cu12==11.0.2.54
|
||||||
nvidia-curand-cu12==10.3.2.106
|
nvidia-curand-cu12==10.3.2.106
|
||||||
nvidia-cusolver-cu12==11.4.5.107
|
nvidia-cusolver-cu12==11.4.5.107
|
||||||
nvidia-cusparse-cu12==12.1.0.106
|
nvidia-cusparse-cu12==12.1.0.106
|
||||||
nvidia-nccl-cu12==2.18.1
|
nvidia-nccl-cu12==2.19.3
|
||||||
nvidia-nvjitlink-cu12==12.3.101
|
nvidia-nvjitlink-cu12==12.3.101
|
||||||
nvidia-nvtx-cu12==12.1.105
|
nvidia-nvtx-cu12==12.1.105
|
||||||
opencv-python==4.8.1.78
|
opencv-python==4.9.0.80
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
pandas==2.1.4
|
pandas==2.2.0
|
||||||
pillow==10.1.0
|
pillow==10.2.0
|
||||||
piq==0.8.0
|
piq==0.8.0
|
||||||
prodigyopt==1.0
|
prodigyopt==1.0
|
||||||
protobuf==4.25.1
|
protobuf==4.25.2
|
||||||
psutil==5.9.6
|
psutil==5.9.8
|
||||||
pyarrow==14.0.1
|
pyarrow==15.0.0
|
||||||
pyarrow-hotfix==0.6
|
pyarrow-hotfix==0.6
|
||||||
pydantic==2.5.2
|
pydantic==2.6.0
|
||||||
pydantic-core==2.14.5
|
pydantic-core==2.16.1
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
pytz==2023.3.post1
|
pytz==2023.4
|
||||||
pywavelets==1.5.0
|
pywavelets==1.5.0
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
regex==2023.10.3
|
regex==2023.12.25
|
||||||
requests==2.31.0
|
requests==2.31.0
|
||||||
safetensors==0.4.1
|
safetensors==0.4.2
|
||||||
scipy==1.11.4
|
scipy==1.12.0
|
||||||
segment-anything-py==1.0
|
segment-anything-py==1.0
|
||||||
sentry-sdk==1.38.0
|
sentry-sdk==1.40.0
|
||||||
setproctitle==1.3.3
|
setproctitle==1.3.3
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
smmap==5.0.1
|
smmap==5.0.1
|
||||||
sympy==1.12
|
sympy==1.12
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.1
|
||||||
tomli==2.0.1
|
tomli==2.0.1
|
||||||
torch==2.1.1
|
torch==2.2.0
|
||||||
torchvision==0.16.1
|
torchvision==0.17.0
|
||||||
tqdm==4.66.1
|
tqdm==4.66.1
|
||||||
transformers==4.35.2
|
transformers==4.37.2
|
||||||
triton==2.1.0
|
triton==2.2.0
|
||||||
typeguard==2.13.3
|
typeguard==2.13.3
|
||||||
typing-extensions==4.8.0
|
typing-extensions==4.9.0
|
||||||
tzdata==2023.3
|
tzdata==2023.4
|
||||||
urllib3==2.1.0
|
urllib3==2.2.0
|
||||||
wandb==0.16.1
|
wandb==0.16.2
|
||||||
xxhash==3.4.1
|
xxhash==3.4.1
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
zipp==3.17.0
|
zipp==3.17.0
|
||||||
# The following packages are considered to be unsafe in a requirements file:
|
# The following packages are considered to be unsafe in a requirements file:
|
||||||
setuptools==69.0.2
|
setuptools==69.0.3
|
||||||
|
|
|
@ -21,10 +21,11 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
) -> None:
|
) -> None:
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(fl.Lambda(func=self.lookup))
|
super().__init__(fl.Lambda(func=self.lookup))
|
||||||
self.old_weight = cast(Parameter, target.weight)
|
p = Parameter(
|
||||||
self.new_weight = Parameter(
|
|
||||||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||||
) # requires_grad=True by default
|
) # requires_grad=True by default
|
||||||
|
self.old_weight = cast(Parameter, target.weight)
|
||||||
|
self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||||
|
|
||||||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||||
def lookup(self, x: Tensor) -> Tensor:
|
def lookup(self, x: Tensor) -> Tensor:
|
||||||
|
@ -33,9 +34,8 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||||
|
|
||||||
def add_embedding(self, embedding: Tensor) -> None:
|
def add_embedding(self, embedding: Tensor) -> None:
|
||||||
assert embedding.shape == (self.old_weight.shape[1],)
|
assert embedding.shape == (self.old_weight.shape[1],)
|
||||||
self.new_weight = Parameter(
|
p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
|
||||||
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
|
self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_embeddings(self) -> int:
|
def num_embeddings(self) -> int:
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
@ -60,18 +62,18 @@ class LayerScale(fl.WeightedModule):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
|
|
||||||
self.register_parameter(
|
p = torch.nn.Parameter(
|
||||||
name="weight",
|
torch.full(
|
||||||
param=torch.nn.Parameter(
|
size=(embedding_dim,),
|
||||||
torch.full(
|
fill_value=init_value,
|
||||||
size=(embedding_dim,),
|
dtype=dtype,
|
||||||
fill_value=init_value,
|
device=device,
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||||
|
self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return x * self.weight
|
return x * self.weight
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from jaxtyping import Float, Int
|
from jaxtyping import Float, Int
|
||||||
|
@ -180,9 +181,9 @@ class MaskEncoder(fl.Chain):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.register_parameter(
|
p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
|
||||||
"no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
|
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||||
)
|
self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
|
||||||
|
|
||||||
def get_no_mask_dense_embedding(
|
def get_no_mask_dense_embedding(
|
||||||
self, image_embedding_size: tuple[int, int], batch_size: int = 1
|
self, image_embedding_size: tuple[int, int], batch_size: int = 1
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 1 MiB After Width: | Height: | Size: 1 MiB |
|
@ -1,6 +1,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -100,7 +101,8 @@ def test_count_learnable_parameters_with_params() -> None:
|
||||||
nn.Parameter(torch.randn(5), requires_grad=False),
|
nn.Parameter(torch.randn(5), requires_grad=False),
|
||||||
nn.Parameter(torch.randn(3, 3), requires_grad=True),
|
nn.Parameter(torch.randn(3, 3), requires_grad=True),
|
||||||
]
|
]
|
||||||
assert count_learnable_parameters(params) == 13
|
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||||
|
assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
|
||||||
|
|
||||||
|
|
||||||
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
|
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
|
||||||
|
|
Loading…
Reference in a new issue