diff --git a/pyproject.toml b/pyproject.toml
index 28275bd..55548d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,11 +54,11 @@ build-backend = "hatchling.build"
 [tool.rye]
 managed = true
 dev-dependencies = [
-    "pyright == 1.1.342",
-    "ruff>=0.0.292",
+    "pyright==1.1.349",
+    "ruff>=0.1.15",
     "docformatter>=1.7.5",
-    "pytest>=7.4.2",
-    "mkdocs-material>=9.5.3",
+    "pytest>=8.0.0",
+    "mkdocs-material>=9.5.6",
     "coverage>=7.4.1",
     "mkdocstrings[python]>=0.24.0",
 ]
diff --git a/requirements.lock b/requirements.lock
index 02f7619..2d3c76d 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -8,38 +8,38 @@
 #   with-sources: false
 
 -e file:.
-aiohttp==3.9.1
+aiohttp==3.9.3
 aiosignal==1.3.1
 annotated-types==0.6.0
 appdirs==1.4.4
 async-timeout==4.0.3
-attrs==23.1.0
-bitsandbytes==0.41.3
+attrs==23.2.0
+bitsandbytes==0.42.0
 certifi==2023.11.17
 charset-normalizer==3.3.2
 click==8.1.7
-datasets==2.15.0
-diffusers==0.24.0
+datasets==2.16.1
+diffusers==0.25.1
 dill==0.3.7
 docker-pycreds==0.4.0
 filelock==3.13.1
-frozenlist==1.4.0
+frozenlist==1.4.1
 fsspec==2023.10.0
 gitdb==4.0.11
-gitpython==3.1.40
-huggingface-hub==0.19.4
+gitpython==3.1.41
+huggingface-hub==0.20.3
 idna==3.6
-importlib-metadata==7.0.0
+importlib-metadata==7.0.1
 invisible-watermark==0.2.0
-jaxtyping==0.2.24
-jinja2==3.1.2
+jaxtyping==0.2.25
+jinja2==3.1.3
 loguru==0.7.2
-markupsafe==2.1.3
+markupsafe==2.1.4
 mpmath==1.3.0
 multidict==6.0.4
 multiprocess==0.70.15
 networkx==3.2.1
-numpy==1.26.2
+numpy==1.26.3
 nvidia-cublas-cu12==12.1.3.1
 nvidia-cuda-cupti-cu12==12.1.105
 nvidia-cuda-nvrtc-cu12==12.1.105
@@ -49,49 +49,49 @@ nvidia-cufft-cu12==11.0.2.54
 nvidia-curand-cu12==10.3.2.106
 nvidia-cusolver-cu12==11.4.5.107
 nvidia-cusparse-cu12==12.1.0.106
-nvidia-nccl-cu12==2.18.1
+nvidia-nccl-cu12==2.19.3
 nvidia-nvjitlink-cu12==12.3.101
 nvidia-nvtx-cu12==12.1.105
-opencv-python==4.8.1.78
+opencv-python==4.9.0.80
 packaging==23.2
-pandas==2.1.4
-pillow==10.1.0
+pandas==2.2.0
+pillow==10.2.0
 piq==0.8.0
 prodigyopt==1.0
-protobuf==4.25.1
-psutil==5.9.6
-pyarrow==14.0.1
+protobuf==4.25.2
+psutil==5.9.8
+pyarrow==15.0.0
 pyarrow-hotfix==0.6
-pydantic==2.5.2
-pydantic-core==2.14.5
+pydantic==2.6.0
+pydantic-core==2.16.1
 python-dateutil==2.8.2
-pytz==2023.3.post1
+pytz==2023.4
 pywavelets==1.5.0
 pyyaml==6.0.1
-regex==2023.10.3
+regex==2023.12.25
 requests==2.31.0
-safetensors==0.4.1
-scipy==1.11.4
+safetensors==0.4.2
+scipy==1.12.0
 segment-anything-py==1.0
-sentry-sdk==1.38.0
+sentry-sdk==1.40.0
 setproctitle==1.3.3
 six==1.16.0
 smmap==5.0.1
 sympy==1.12
-tokenizers==0.15.0
+tokenizers==0.15.1
 tomli==2.0.1
-torch==2.1.1
-torchvision==0.16.1
+torch==2.2.0
+torchvision==0.17.0
 tqdm==4.66.1
-transformers==4.35.2
-triton==2.1.0
+transformers==4.37.2
+triton==2.2.0
 typeguard==2.13.3
-typing-extensions==4.8.0
-tzdata==2023.3
-urllib3==2.1.0
-wandb==0.16.1
+typing-extensions==4.9.0
+tzdata==2023.4
+urllib3==2.2.0
+wandb==0.16.2
 xxhash==3.4.1
 yarl==1.9.4
 zipp==3.17.0
 # The following packages are considered to be unsafe in a requirements file:
-setuptools==69.0.2
+setuptools==69.0.3
diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index 403b27f..3b08765 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -21,10 +21,11 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     ) -> None:
         with self.setup_adapter(target):
             super().__init__(fl.Lambda(func=self.lookup))
-        self.old_weight = cast(Parameter, target.weight)
-        self.new_weight = Parameter(
+        p = Parameter(
             zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
+        self.old_weight = cast(Parameter, target.weight)
+        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
 
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
@@ -33,9 +34,8 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
 
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
-        self.new_weight = Parameter(
-            cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
-        )
+        p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
+        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
 
     @property
     def num_embeddings(self) -> int:
diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py
index 46ddc19..2045fd3 100644
--- a/src/refiners/foundationals/dinov2/vit.py
+++ b/src/refiners/foundationals/dinov2/vit.py
@@ -1,3 +1,5 @@
+from typing import cast
+
 import torch
 from torch import Tensor
 
@@ -60,18 +62,18 @@ class LayerScale(fl.WeightedModule):
         super().__init__()
         self.embedding_dim = embedding_dim
 
-        self.register_parameter(
-            name="weight",
-            param=torch.nn.Parameter(
-                torch.full(
-                    size=(embedding_dim,),
-                    fill_value=init_value,
-                    dtype=dtype,
-                    device=device,
-                ),
+        p = torch.nn.Parameter(
+            torch.full(
+                size=(embedding_dim,),
+                fill_value=init_value,
+                dtype=dtype,
+                device=device,
             ),
         )
+        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
+
     def forward(self, x: Tensor) -> Tensor:
         return x * self.weight
diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py
index 222ae4e..edcd116 100644
--- a/src/refiners/foundationals/segment_anything/prompt_encoder.py
+++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py
@@ -1,5 +1,6 @@
 from collections.abc import Sequence
 from enum import Enum, auto
+from typing import cast
 
 import torch
 from jaxtyping import Float, Int
@@ -180,9 +181,9 @@ class MaskEncoder(fl.Chain):
                 dtype=dtype,
             ),
         )
-        self.register_parameter(
-            "no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
-        )
+        p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
+        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
 
     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1
diff --git a/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png b/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png
index 504a89a..a286bd2 100644
Binary files a/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png and b/tests/e2e/test_diffusion_ref/expected_image_sdxl_ip_adapter_plus_woman.png differ
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 3d7be73..35bc8e4 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
+from typing import cast
 from warnings import warn
 
 import pytest
@@ -100,7 +101,8 @@ def test_count_learnable_parameters_with_params() -> None:
         nn.Parameter(torch.randn(5), requires_grad=False),
         nn.Parameter(torch.randn(3, 3), requires_grad=True),
     ]
-    assert count_learnable_parameters(params) == 13
+    # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+    assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
 
 
 def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
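
Note (not part of the patch): the repeated `cast(..., p)` pattern above works around pyright inferring `nn.Parameter(...)` as `Tensor` under the PyTorch 2.2 type stubs (https://github.com/pytorch/pytorch/issues/118736); the cast is a no-op at runtime. A minimal sketch of the same workaround, using a hypothetical `Scale` module for illustration rather than any module from this repository:

from typing import cast

import torch
import torch.nn as nn


class Scale(nn.Module):
    """Hypothetical module showing the cast-back-to-Parameter workaround."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # With the PyTorch 2.2 stubs, pyright may type `p` as `Tensor` instead of `Parameter`.
        p = nn.Parameter(torch.ones(dim))
        # Casting back to `nn.Parameter` satisfies `register_parameter`'s signature for the
        # type checker; the object itself is unchanged.
        self.register_parameter("weight", cast(nn.Parameter, p))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight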