diff --git a/requirements.lock b/requirements.lock
index 70c2dd4..a56a5a5 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -10,7 +10,9 @@
 #   universal: false
 
 -e file:.
-aiohttp==3.9.5
+aiohappyeyeballs==2.3.4
+    # via aiohttp
+aiohttp==3.10.0
     # via datasets
     # via fsspec
 aiosignal==1.3.1
@@ -27,20 +29,22 @@ attrs==23.2.0
     # via referencing
 babel==2.15.0
     # via mkdocs-material
-bitsandbytes==0.43.1
+backports-strenum==1.3.1
+    # via griffe
+bitsandbytes==0.43.3
     # via refiners
 black==24.4.2
     # via refiners
-boto3==1.34.131
+boto3==1.34.152
     # via neptune
-botocore==1.34.131
+botocore==1.34.152
     # via boto3
     # via s3transfer
 bravado==11.0.3
     # via neptune
 bravado-core==6.1.1
     # via bravado
-certifi==2024.6.2
+certifi==2024.7.4
     # via requests
     # via sentry-sdk
 charset-normalizer==3.3.2
@@ -56,7 +60,7 @@ colorama==0.4.6
     # via mkdocs-material
 datasets==2.20.0
     # via refiners
-diffusers==0.29.1
+diffusers==0.29.2
     # via refiners
 dill==0.3.8
     # via datasets
@@ -89,9 +93,9 @@ gitpython==3.1.43
     # via neptune
     # via refiners
     # via wandb
-griffe==0.47.0
+griffe==0.48.0
     # via mkdocstrings-python
-huggingface-hub==0.23.4
+huggingface-hub==0.24.5
     # via datasets
     # via diffusers
     # via timm
@@ -101,11 +105,13 @@ idna==3.7
     # via jsonschema
     # via requests
     # via yarl
-importlib-metadata==7.2.1
+importlib-metadata==8.2.0
     # via diffusers
+importlib-resources==6.4.0
+    # via swagger-spec-validator
 isoduration==20.11.0
     # via jsonschema
-jaxtyping==0.2.30
+jaxtyping==0.2.33
     # via refiners
 jinja2==3.1.4
     # via mkdocs
@@ -119,7 +125,7 @@ jsonpointer==3.0.0
     # via jsonschema
 jsonref==1.1.0
     # via bravado-core
-jsonschema==4.22.0
+jsonschema==4.23.0
     # via bravado-core
     # via swagger-spec-validator
 jsonschema-specifications==2023.12.1
@@ -151,14 +157,14 @@ mkdocs-get-deps==0.2.0
     # via mkdocs
 mkdocs-literate-nav==0.6.1
     # via refiners
-mkdocs-material==9.5.27
+mkdocs-material==9.5.30
     # via refiners
 mkdocs-material-extensions==1.3.1
     # via mkdocs-material
-mkdocstrings==0.25.1
+mkdocstrings==0.25.2
     # via mkdocstrings-python
     # via refiners
-mkdocstrings-python==1.10.5
+mkdocstrings-python==1.10.7
     # via mkdocstrings
 monotonic==1.6
     # via bravado
@@ -178,7 +184,7 @@ neptune==1.10.4
     # via refiners
 networkx==3.3
     # via torch
-numpy==2.0.0
+numpy==2.0.1
     # via bitsandbytes
     # via datasets
     # via diffusers
@@ -197,7 +203,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
     # via torch
 nvidia-cuda-runtime-cu12==12.1.105
     # via torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via torch
 nvidia-cufft-cu12==11.0.2.54
     # via torch
@@ -210,7 +216,7 @@ nvidia-cusparse-cu12==12.1.0.106
     # via torch
 nvidia-nccl-cu12==2.20.5
     # via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.6.20
     # via nvidia-cusolver-cu12
     # via nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
@@ -248,24 +254,24 @@ platformdirs==4.2.2
     # via wandb
 prodigyopt==1.0
     # via refiners
-protobuf==5.27.1
+protobuf==5.27.3
     # via wandb
 psutil==6.0.0
     # via neptune
     # via wandb
-pyarrow==16.1.0
+pyarrow==17.0.0
     # via datasets
 pyarrow-hotfix==0.6
     # via datasets
-pydantic==2.7.4
+pydantic==2.8.2
     # via refiners
-pydantic-core==2.18.4
+pydantic-core==2.20.1
     # via pydantic
 pygments==2.18.0
     # via mkdocs-material
-pyjwt==2.8.0
+pyjwt==2.9.0
     # via neptune
-pymdown-extensions==10.8.1
+pymdown-extensions==10.9
     # via mkdocs-material
     # via mkdocstrings
 python-dateutil==2.9.0.post0
@@ -296,7 +302,7 @@ pyyaml-env-tag==0.1
 referencing==0.35.1
     # via jsonschema
     # via jsonschema-specifications
-regex==2024.5.15
+regex==2024.7.24
     # via diffusers
     # via mkdocs-material
     # via transformers
@@ -318,10 +324,10 @@ rfc3339-validator==0.1.4
     # via jsonschema
 rfc3986-validator==0.1.1
     # via jsonschema
-rpds-py==0.18.1
+rpds-py==0.19.1
     # via jsonschema
     # via referencing
-s3transfer==0.10.1
+s3transfer==0.10.2
     # via boto3
 safetensors==0.4.3
     # via diffusers
@@ -330,13 +336,13 @@ safetensors==0.4.3
     # via transformers
 segment-anything-hq==0.3
     # via refiners
-segment-anything-py==1.0
+segment-anything-py==1.0.1
     # via refiners
-sentry-sdk==2.6.0
+sentry-sdk==2.12.0
     # via wandb
 setproctitle==1.3.3
     # via wandb
-setuptools==70.1.0
+setuptools==72.1.0
     # via wandb
 simplejson==3.19.2
     # via bravado
@@ -350,26 +356,26 @@ six==1.16.0
     # via rfc3339-validator
 smmap==5.0.1
     # via gitdb
-swagger-spec-validator==3.0.3
+swagger-spec-validator==3.0.4
     # via bravado-core
     # via neptune
-sympy==1.12.1
+sympy==1.13.1
     # via torch
-timm==1.0.7
+timm==1.0.8
     # via refiners
 tokenizers==0.19.1
     # via transformers
 tomli==2.0.1
     # via black
     # via refiners
-torch==2.3.1
+torch==2.4.0
     # via bitsandbytes
     # via refiners
     # via segment-anything-hq
     # via segment-anything-py
     # via timm
     # via torchvision
-torchvision==0.18.1
+torchvision==0.19.0
     # via piq
     # via refiners
     # via segment-anything-hq
@@ -380,9 +386,9 @@ tqdm==4.66.4
     # via huggingface-hub
     # via refiners
     # via transformers
-transformers==4.41.2
+transformers==4.43.3
     # via refiners
-triton==2.3.1
+triton==3.0.0
     # via torch
 typeguard==2.13.3
     # via jaxtyping
@@ -406,7 +412,7 @@ urllib3==2.2.2
     # via neptune
     # via requests
     # via sentry-sdk
-wandb==0.17.2
+wandb==0.17.5
     # via refiners
 watchdog==4.0.1
     # via mkdocs
diff --git a/scripts/conversion/convert_transformers_clip_image_model.py b/scripts/conversion/convert_transformers_clip_image_model.py
index 431b0f2..121b370 100644
--- a/scripts/conversion/convert_transformers_clip_image_model.py
+++ b/scripts/conversion/convert_transformers_clip_image_model.py
@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+from typing import NamedTuple, cast
 
 import torch
 from torch import nn
@@ -20,6 +21,20 @@ class Args(argparse.Namespace):
     threshold: float
 
 
+class CLIPImageEncoderConfig(NamedTuple):
+    architectures: list[str]
+    num_channels: int
+    hidden_size: int
+    hidden_act: str
+    image_size: int
+    projection_dim: int
+    patch_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    intermediate_size: int
+    layer_norm_eps: float
+
+
 def setup_converter(args: Args) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     source: nn.Module = CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
@@ -28,34 +43,26 @@ def setup_converter(args: Args) -> ModelConverter:
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    num_channels: int = source.config.num_channels  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    image_size: int = source.config.image_size  # type: ignore
-    patch_size: int = source.config.patch_size  # type: ignore
-    output_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feedforward_dim: int = source.config.intermediate_size  # type: ignore
-    activation: str = source.config.hidden_act  # type: ignore
-    layer_norm_eps: float = source.config.layer_norm_eps  # type: ignore
+    config = cast(CLIPImageEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
 
-    assert architecture == "CLIPVisionModelWithProjection", f"Unsupported architecture: {architecture}"
-    assert num_channels == 3, f"Expected 3 input channels, got {num_channels}"
-    assert activation == "gelu", f"Unsupported activation: {activation}"
+    assert (
+        config.architectures[0] == "CLIPVisionModelWithProjection"
+    ), f"Unsupported architecture: {config.architectures[0]}"
+    assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
+    assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"
 
     target = CLIPImageEncoder(
-        image_size=image_size,
-        embedding_dim=embedding_dim,
-        output_dim=output_dim,
-        patch_size=patch_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feedforward_dim,
-        layer_norm_eps=layer_norm_eps,
+        image_size=config.image_size,
+        embedding_dim=config.hidden_size,
+        output_dim=config.projection_dim,
+        patch_size=config.patch_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
+        layer_norm_eps=config.layer_norm_eps,
     )
 
-    x = torch.randn(1, 3, image_size, image_size)
+    x = torch.randn(1, 3, config.image_size, config.image_size)
 
     converter = ModelConverter(source_model=source, target_model=target, verbose=True)
 
diff --git a/scripts/conversion/convert_transformers_clip_text_model.py b/scripts/conversion/convert_transformers_clip_text_model.py
index c6a4234..fd45767 100644
--- a/scripts/conversion/convert_transformers_clip_text_model.py
+++ b/scripts/conversion/convert_transformers_clip_text_model.py
@@ -1,6 +1,6 @@
 import argparse
 from pathlib import Path
-from typing import cast
+from typing import NamedTuple, cast
 
 from torch import nn
 from transformers import CLIPTextModel, CLIPTextModelWithProjection  # type: ignore
@@ -21,6 +21,18 @@ class Args(argparse.Namespace):
     verbose: bool
 
 
+class CLIPTextEncoderConfig(NamedTuple):
+    architectures: list[str]
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    hidden_act: str
+    layer_norm_eps: float
+    projection_dim: int
+
+
 def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
@@ -30,19 +42,17 @@ def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    projection_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feed_forward_dim: int = source.config.intermediate_size  # type: ignore
-    use_quick_gelu: bool = source.config.hidden_act == "quick_gelu"  # type: ignore
+    config = cast(CLIPTextEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
+    architecture: str = config.architectures[0]
+    embedding_dim: int = config.hidden_size
+    projection_dim: int = config.projection_dim
+    use_quick_gelu = config.hidden_act == "quick_gelu"
     assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
     target = CLIPTextEncoder(
-        embedding_dim=embedding_dim,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feed_forward_dim,
+        embedding_dim=config.hidden_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
         use_quick_gelu=use_quick_gelu,
     )
     if architecture == "CLIPTextModelWithProjection":
diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index 3a7c99a..bda9ce2 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -25,7 +25,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
             zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p
 
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
@@ -35,7 +35,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
         p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p
 
     @property
     def num_embeddings(self) -> int:
diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py
index 7a4e489..4cf7cf0 100644
--- a/src/refiners/foundationals/dinov2/vit.py
+++ b/src/refiners/foundationals/dinov2/vit.py
@@ -1,5 +1,4 @@
 from math import sqrt
-from typing import cast
 
 import torch
 from torch import Tensor
@@ -123,8 +122,7 @@ class LayerScale(fl.WeightedModule):
             ),
         )
 
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
+        self.register_parameter(name="weight", param=p)
 
     def forward(self, x: Tensor) -> Tensor:
         return x * self.weight
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py
index 69e9da2..422a242 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py
@@ -56,7 +56,7 @@ class Euler(Solver):
     def _generate_sigmas(self) -> Tensor:
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
-        sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
+        sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
         sigmas = torch.cat([sigmas, tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)
 
diff --git a/src/refiners/foundationals/segment_anything/prompt_encoder.py b/src/refiners/foundationals/segment_anything/prompt_encoder.py
index f5553f3..b9cb4e4 100644
--- a/src/refiners/foundationals/segment_anything/prompt_encoder.py
+++ b/src/refiners/foundationals/segment_anything/prompt_encoder.py
@@ -182,8 +182,7 @@ class MaskEncoder(fl.Chain):
             ),
         )
         p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
+        self.register_parameter("no_mask_embedding", p)
 
     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1
diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index 4e68461..7b59415 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -8,7 +8,10 @@ from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, BeforeValidator, ConfigDict
 from torch import Tensor
-from torch.optim import SGD, Adam, AdamW, Optimizer
+from torch.optim.adam import Adam
+from torch.optim.adamw import AdamW
+from torch.optim.optimizer import Optimizer
+from torch.optim.sgd import SGD
 
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 924fe15..adecfe3 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -7,7 +7,6 @@ import torch
 from loguru import logger
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.autograd import backward
-from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
     CosineAnnealingWarmRestarts,
@@ -21,6 +20,7 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
+from torch.optim.optimizer import Optimizer
 
 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -226,7 +226,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
             case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
-                    optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
+                    optimizer=self.optimizer,
+                    T_max=scheduler_step_size,
+                    eta_min=config.eta_min,  # pyright: ignore[reportArgumentType]
                 )
             case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
diff --git a/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png
index 3f6f69d..fbde686 100644
Binary files a/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png and b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_plus_statue.png differ
diff --git a/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_woman.png b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_woman.png
index 8a0cbdf..e0156b4 100644
Binary files a/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_woman.png and b/tests/e2e/test_diffusion_ref/expected_image_ip_adapter_woman.png differ
diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py
index c412dc5..5c2222a 100644
--- a/tests/foundationals/segment_anything/test_hq_sam.py
+++ b/tests/foundationals/segment_anything/test_hq_sam.py
@@ -12,7 +12,7 @@ from segment_anything_hq import (  # type: ignore
 )
 from segment_anything_hq.modeling.sam import Sam  # type: ignore
 from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
-from torch import optim
+from torch.optim.sgd import SGD
 
 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.segment_anything.hq_sam import (
@@ -125,13 +125,13 @@ def test_mask_decoder_tokens_extender() -> None:
         case _:
             raise ValueError
 
-    optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
+    optimizer = SGD(mask_decoder_tokens.parameters(), lr=10)
     optimizer.zero_grad()
 
     ones = torch.ones_like(tokens_before)
     loss = torch.nn.functional.mse_loss(tokens_before, ones)
-    loss.backward()  # type: ignore
-    optimizer.step()
+    loss.backward()  # pyright: ignore[reportUnknownMemberType]
+    optimizer.step()  # pyright: ignore[reportUnknownMemberType]
 
     tokens_after = mask_decoder_tokens()
 
diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py
index fc40810..b18a77b 100644
--- a/tests/foundationals/segment_anything/test_sam.py
+++ b/tests/foundationals/segment_anything/test_sam.py
@@ -379,7 +379,7 @@ def test_predictor_dense_mask(
     mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW
 
     assert np.allclose(
-        mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
+        mask_input.cpu(), facebook_mask_input, atol=1e-1
     )  # Lower doesn't pass, but it's close enough for logits
 
     refiners_prompt = SAMPrompt(
diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py
index 1f3342d..e370a80 100644
--- a/tests/training_utils/test_trainer.py
+++ b/tests/training_utils/test_trainer.py
@@ -2,12 +2,11 @@ import random
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import cast
 
 import pytest
 import torch
 from torch import Tensor, nn
-from torch.optim import SGD
+from torch.optim.sgd import SGD
 
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
@@ -188,8 +187,7 @@ def test_count_learnable_parameters_with_params() -> None:
         nn.Parameter(torch.randn(5), requires_grad=False),
         nn.Parameter(torch.randn(3, 3), requires_grad=True),
     ]
-    # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-    assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
+    assert count_learnable_parameters(params) == 13
 
 
 def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: