add support for pytorch 2.2 (2.1 is still supported)

also bump all dev dependencies to their latest version
This commit is contained in:
Pierre Chapuis 2024-01-31 12:15:46 +01:00
parent 45357c5548
commit 7eb8eb4c68
7 changed files with 64 additions and 59 deletions

View file

@ -54,11 +54,11 @@ build-backend = "hatchling.build"
[tool.rye] [tool.rye]
managed = true managed = true
dev-dependencies = [ dev-dependencies = [
"pyright == 1.1.342", "pyright==1.1.349",
"ruff>=0.0.292", "ruff>=0.1.15",
"docformatter>=1.7.5", "docformatter>=1.7.5",
"pytest>=7.4.2", "pytest>=8.0.0",
"mkdocs-material>=9.5.3", "mkdocs-material>=9.5.6",
"coverage>=7.4.1", "coverage>=7.4.1",
"mkdocstrings[python]>=0.24.0", "mkdocstrings[python]>=0.24.0",
] ]

View file

@ -8,38 +8,38 @@
# with-sources: false # with-sources: false
-e file:. -e file:.
aiohttp==3.9.1 aiohttp==3.9.3
aiosignal==1.3.1 aiosignal==1.3.1
annotated-types==0.6.0 annotated-types==0.6.0
appdirs==1.4.4 appdirs==1.4.4
async-timeout==4.0.3 async-timeout==4.0.3
attrs==23.1.0 attrs==23.2.0
bitsandbytes==0.41.3 bitsandbytes==0.42.0
certifi==2023.11.17 certifi==2023.11.17
charset-normalizer==3.3.2 charset-normalizer==3.3.2
click==8.1.7 click==8.1.7
datasets==2.15.0 datasets==2.16.1
diffusers==0.24.0 diffusers==0.25.1
dill==0.3.7 dill==0.3.7
docker-pycreds==0.4.0 docker-pycreds==0.4.0
filelock==3.13.1 filelock==3.13.1
frozenlist==1.4.0 frozenlist==1.4.1
fsspec==2023.10.0 fsspec==2023.10.0
gitdb==4.0.11 gitdb==4.0.11
gitpython==3.1.40 gitpython==3.1.41
huggingface-hub==0.19.4 huggingface-hub==0.20.3
idna==3.6 idna==3.6
importlib-metadata==7.0.0 importlib-metadata==7.0.1
invisible-watermark==0.2.0 invisible-watermark==0.2.0
jaxtyping==0.2.24 jaxtyping==0.2.25
jinja2==3.1.2 jinja2==3.1.3
loguru==0.7.2 loguru==0.7.2
markupsafe==2.1.3 markupsafe==2.1.4
mpmath==1.3.0 mpmath==1.3.0
multidict==6.0.4 multidict==6.0.4
multiprocess==0.70.15 multiprocess==0.70.15
networkx==3.2.1 networkx==3.2.1
numpy==1.26.2 numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1 nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105
@ -49,49 +49,49 @@ nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106 nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107 nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106 nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1 nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.3.101 nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105 nvidia-nvtx-cu12==12.1.105
opencv-python==4.8.1.78 opencv-python==4.9.0.80
packaging==23.2 packaging==23.2
pandas==2.1.4 pandas==2.2.0
pillow==10.1.0 pillow==10.2.0
piq==0.8.0 piq==0.8.0
prodigyopt==1.0 prodigyopt==1.0
protobuf==4.25.1 protobuf==4.25.2
psutil==5.9.6 psutil==5.9.8
pyarrow==14.0.1 pyarrow==15.0.0
pyarrow-hotfix==0.6 pyarrow-hotfix==0.6
pydantic==2.5.2 pydantic==2.6.0
pydantic-core==2.14.5 pydantic-core==2.16.1
python-dateutil==2.8.2 python-dateutil==2.8.2
pytz==2023.3.post1 pytz==2023.4
pywavelets==1.5.0 pywavelets==1.5.0
pyyaml==6.0.1 pyyaml==6.0.1
regex==2023.10.3 regex==2023.12.25
requests==2.31.0 requests==2.31.0
safetensors==0.4.1 safetensors==0.4.2
scipy==1.11.4 scipy==1.12.0
segment-anything-py==1.0 segment-anything-py==1.0
sentry-sdk==1.38.0 sentry-sdk==1.40.0
setproctitle==1.3.3 setproctitle==1.3.3
six==1.16.0 six==1.16.0
smmap==5.0.1 smmap==5.0.1
sympy==1.12 sympy==1.12
tokenizers==0.15.0 tokenizers==0.15.1
tomli==2.0.1 tomli==2.0.1
torch==2.1.1 torch==2.2.0
torchvision==0.16.1 torchvision==0.17.0
tqdm==4.66.1 tqdm==4.66.1
transformers==4.35.2 transformers==4.37.2
triton==2.1.0 triton==2.2.0
typeguard==2.13.3 typeguard==2.13.3
typing-extensions==4.8.0 typing-extensions==4.9.0
tzdata==2023.3 tzdata==2023.4
urllib3==2.1.0 urllib3==2.2.0
wandb==0.16.1 wandb==0.16.2
xxhash==3.4.1 xxhash==3.4.1
yarl==1.9.4 yarl==1.9.4
zipp==3.17.0 zipp==3.17.0
# The following packages are considered to be unsafe in a requirements file: # The following packages are considered to be unsafe in a requirements file:
setuptools==69.0.2 setuptools==69.0.3

View file

@ -21,10 +21,11 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
) -> None: ) -> None:
with self.setup_adapter(target): with self.setup_adapter(target):
super().__init__(fl.Lambda(func=self.lookup)) super().__init__(fl.Lambda(func=self.lookup))
self.old_weight = cast(Parameter, target.weight) p = Parameter(
self.new_weight = Parameter(
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype) zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
) # requires_grad=True by default ) # requires_grad=True by default
self.old_weight = cast(Parameter, target.weight)
self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
def lookup(self, x: Tensor) -> Tensor: def lookup(self, x: Tensor) -> Tensor:
@ -33,9 +34,8 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
def add_embedding(self, embedding: Tensor) -> None: def add_embedding(self, embedding: Tensor) -> None:
assert embedding.shape == (self.old_weight.shape[1],) assert embedding.shape == (self.old_weight.shape[1],)
self.new_weight = Parameter( p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]) self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
)
@property @property
def num_embeddings(self) -> int: def num_embeddings(self) -> int:

View file

@ -1,3 +1,5 @@
from typing import cast
import torch import torch
from torch import Tensor from torch import Tensor
@ -60,18 +62,18 @@ class LayerScale(fl.WeightedModule):
super().__init__() super().__init__()
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.register_parameter( p = torch.nn.Parameter(
name="weight",
param=torch.nn.Parameter(
torch.full( torch.full(
size=(embedding_dim,), size=(embedding_dim,),
fill_value=init_value, fill_value=init_value,
dtype=dtype, dtype=dtype,
device=device, device=device,
), ),
),
) )
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return x * self.weight return x * self.weight

View file

@ -1,5 +1,6 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, auto from enum import Enum, auto
from typing import cast
import torch import torch
from jaxtyping import Float, Int from jaxtyping import Float, Int
@ -180,9 +181,9 @@ class MaskEncoder(fl.Chain):
dtype=dtype, dtype=dtype,
), ),
) )
self.register_parameter( p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
"no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype)) # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
) self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
def get_no_mask_dense_embedding( def get_no_mask_dense_embedding(
self, image_embedding_size: tuple[int, int], batch_size: int = 1 self, image_embedding_size: tuple[int, int], batch_size: int = 1

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1 MiB

After

Width:  |  Height:  |  Size: 1 MiB

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import cast
from warnings import warn from warnings import warn
import pytest import pytest
@ -100,7 +101,8 @@ def test_count_learnable_parameters_with_params() -> None:
nn.Parameter(torch.randn(5), requires_grad=False), nn.Parameter(torch.randn(5), requires_grad=False),
nn.Parameter(torch.randn(3, 3), requires_grad=True), nn.Parameter(torch.randn(3, 3), requires_grad=True),
] ]
assert count_learnable_parameters(params) == 13 # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: