fix typing issues coming from torch 2.4; typing is not guaranteed for torch < 2.4

limiteinductive 2024-08-02 09:56:37 +00:00 committed by Benjamin Trom
parent 10fba3467a
commit 1de567590b
14 changed files with 115 additions and 92 deletions
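Beyond the lockfile bumps below (torch 2.3.1 → 2.4.0 among others), the recurring code change is the removal of `cast(...)` workarounds for pytorch/pytorch#118736: the torch 2.4 stubs type `nn.Parameter` and `Module.register_parameter` correctly, so the casts and the comments explaining them can go. A minimal before/after sketch of the pattern (illustrative, not a line taken from this diff):

```python
from typing import cast

import torch
from torch import nn

module = nn.Module()
p = nn.Parameter(torch.randn(1, 128))

# torch < 2.4 workaround (see https://github.com/pytorch/pytorch/issues/118736):
# `p` was not reliably inferred as `nn.Parameter`, so call sites had to cast.
module.register_parameter(name="weight", param=cast(nn.Parameter, p))

# torch >= 2.4: the stubs infer `nn.Parameter` directly, no cast needed.
module.register_parameter(name="weight", param=p)
```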


@@ -10,7 +10,9 @@
 # universal: false
 -e file:.
-aiohttp==3.9.5
+aiohappyeyeballs==2.3.4
+    # via aiohttp
+aiohttp==3.10.0
     # via datasets
     # via fsspec
 aiosignal==1.3.1
     # via aiohttp
@@ -27,20 +29,22 @@ attrs==23.2.0
     # via referencing
 babel==2.15.0
     # via mkdocs-material
-bitsandbytes==0.43.1
+backports-strenum==1.3.1
+    # via griffe
+bitsandbytes==0.43.3
     # via refiners
 black==24.4.2
     # via refiners
-boto3==1.34.131
+boto3==1.34.152
     # via neptune
-botocore==1.34.131
+botocore==1.34.152
     # via boto3
     # via s3transfer
 bravado==11.0.3
     # via neptune
 bravado-core==6.1.1
     # via bravado
-certifi==2024.6.2
+certifi==2024.7.4
     # via requests
     # via sentry-sdk
 charset-normalizer==3.3.2
@@ -56,7 +60,7 @@ colorama==0.4.6
     # via mkdocs-material
 datasets==2.20.0
     # via refiners
-diffusers==0.29.1
+diffusers==0.29.2
     # via refiners
 dill==0.3.8
     # via datasets
@@ -89,9 +93,9 @@ gitpython==3.1.43
     # via neptune
     # via refiners
     # via wandb
-griffe==0.47.0
+griffe==0.48.0
     # via mkdocstrings-python
-huggingface-hub==0.23.4
+huggingface-hub==0.24.5
     # via datasets
     # via diffusers
     # via timm
@@ -101,11 +105,13 @@ idna==3.7
     # via jsonschema
     # via requests
     # via yarl
-importlib-metadata==7.2.1
+importlib-metadata==8.2.0
     # via diffusers
+importlib-resources==6.4.0
+    # via swagger-spec-validator
 isoduration==20.11.0
     # via jsonschema
-jaxtyping==0.2.30
+jaxtyping==0.2.33
     # via refiners
 jinja2==3.1.4
     # via mkdocs
@@ -119,7 +125,7 @@ jsonpointer==3.0.0
     # via jsonschema
 jsonref==1.1.0
     # via bravado-core
-jsonschema==4.22.0
+jsonschema==4.23.0
     # via bravado-core
     # via swagger-spec-validator
 jsonschema-specifications==2023.12.1
@@ -151,14 +157,14 @@ mkdocs-get-deps==0.2.0
     # via mkdocs
 mkdocs-literate-nav==0.6.1
     # via refiners
-mkdocs-material==9.5.27
+mkdocs-material==9.5.30
     # via refiners
 mkdocs-material-extensions==1.3.1
     # via mkdocs-material
-mkdocstrings==0.25.1
+mkdocstrings==0.25.2
     # via mkdocstrings-python
     # via refiners
-mkdocstrings-python==1.10.5
+mkdocstrings-python==1.10.7
     # via mkdocstrings
 monotonic==1.6
     # via bravado
@@ -178,7 +184,7 @@ neptune==1.10.4
     # via refiners
 networkx==3.3
     # via torch
-numpy==2.0.0
+numpy==2.0.1
     # via bitsandbytes
     # via datasets
     # via diffusers
@@ -197,7 +203,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
     # via torch
 nvidia-cuda-runtime-cu12==12.1.105
     # via torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via torch
 nvidia-cufft-cu12==11.0.2.54
     # via torch
@@ -210,7 +216,7 @@ nvidia-cusparse-cu12==12.1.0.106
     # via torch
 nvidia-nccl-cu12==2.20.5
     # via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.6.20
     # via nvidia-cusolver-cu12
     # via nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
@@ -248,24 +254,24 @@ platformdirs==4.2.2
     # via wandb
 prodigyopt==1.0
     # via refiners
-protobuf==5.27.1
+protobuf==5.27.3
     # via wandb
 psutil==6.0.0
     # via neptune
     # via wandb
-pyarrow==16.1.0
+pyarrow==17.0.0
     # via datasets
 pyarrow-hotfix==0.6
     # via datasets
-pydantic==2.7.4
+pydantic==2.8.2
     # via refiners
-pydantic-core==2.18.4
+pydantic-core==2.20.1
     # via pydantic
 pygments==2.18.0
     # via mkdocs-material
-pyjwt==2.8.0
+pyjwt==2.9.0
     # via neptune
-pymdown-extensions==10.8.1
+pymdown-extensions==10.9
     # via mkdocs-material
     # via mkdocstrings
 python-dateutil==2.9.0.post0
@@ -296,7 +302,7 @@ pyyaml-env-tag==0.1
 referencing==0.35.1
     # via jsonschema
     # via jsonschema-specifications
-regex==2024.5.15
+regex==2024.7.24
     # via diffusers
     # via mkdocs-material
     # via transformers
@@ -318,10 +324,10 @@ rfc3339-validator==0.1.4
     # via jsonschema
 rfc3986-validator==0.1.1
     # via jsonschema
-rpds-py==0.18.1
+rpds-py==0.19.1
     # via jsonschema
     # via referencing
-s3transfer==0.10.1
+s3transfer==0.10.2
     # via boto3
 safetensors==0.4.3
     # via diffusers
@@ -330,13 +336,13 @@ safetensors==0.4.3
     # via transformers
 segment-anything-hq==0.3
     # via refiners
-segment-anything-py==1.0
+segment-anything-py==1.0.1
     # via refiners
-sentry-sdk==2.6.0
+sentry-sdk==2.12.0
     # via wandb
 setproctitle==1.3.3
     # via wandb
-setuptools==70.1.0
+setuptools==72.1.0
     # via wandb
 simplejson==3.19.2
     # via bravado
@@ -350,26 +356,26 @@ six==1.16.0
     # via rfc3339-validator
 smmap==5.0.1
     # via gitdb
-swagger-spec-validator==3.0.3
+swagger-spec-validator==3.0.4
     # via bravado-core
     # via neptune
-sympy==1.12.1
+sympy==1.13.1
     # via torch
-timm==1.0.7
+timm==1.0.8
     # via refiners
 tokenizers==0.19.1
     # via transformers
 tomli==2.0.1
     # via black
     # via refiners
-torch==2.3.1
+torch==2.4.0
     # via bitsandbytes
     # via refiners
     # via segment-anything-hq
     # via segment-anything-py
     # via timm
     # via torchvision
-torchvision==0.18.1
+torchvision==0.19.0
     # via piq
     # via refiners
     # via segment-anything-hq
@@ -380,9 +386,9 @@ tqdm==4.66.4
     # via huggingface-hub
     # via refiners
     # via transformers
-transformers==4.41.2
+transformers==4.43.3
     # via refiners
-triton==2.3.1
+triton==3.0.0
     # via torch
 typeguard==2.13.3
     # via jaxtyping
@@ -406,7 +412,7 @@ urllib3==2.2.2
     # via neptune
     # via requests
     # via sentry-sdk
-wandb==0.17.2
+wandb==0.17.5
     # via refiners
 watchdog==4.0.1
     # via mkdocs


@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+from typing import NamedTuple, cast
 
 import torch
 from torch import nn
@@ -20,6 +21,20 @@ class Args(argparse.Namespace):
     threshold: float
 
 
+class CLIPImageEncoderConfig(NamedTuple):
+    architectures: list[str]
+    num_channels: int
+    hidden_size: int
+    hidden_act: str
+    image_size: int
+    projection_dim: int
+    patch_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    intermediate_size: int
+    layer_norm_eps: float
+
+
 def setup_converter(args: Args) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
     source: nn.Module = CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
@@ -28,34 +43,26 @@ def setup_converter(args: Args) -> ModelConverter:
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    num_channels: int = source.config.num_channels  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    image_size: int = source.config.image_size  # type: ignore
-    patch_size: int = source.config.patch_size  # type: ignore
-    output_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feedforward_dim: int = source.config.intermediate_size  # type: ignore
-    activation: str = source.config.hidden_act  # type: ignore
-    layer_norm_eps: float = source.config.layer_norm_eps  # type: ignore
+    config = cast(CLIPImageEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
-    assert architecture == "CLIPVisionModelWithProjection", f"Unsupported architecture: {architecture}"
-    assert num_channels == 3, f"Expected 3 input channels, got {num_channels}"
-    assert activation == "gelu", f"Unsupported activation: {activation}"
+    assert (
+        config.architectures[0] == "CLIPVisionModelWithProjection"
+    ), f"Unsupported architecture: {config.architectures[0]}"
+    assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
+    assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"
     target = CLIPImageEncoder(
-        image_size=image_size,
-        embedding_dim=embedding_dim,
-        output_dim=output_dim,
-        patch_size=patch_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feedforward_dim,
-        layer_norm_eps=layer_norm_eps,
+        image_size=config.image_size,
+        embedding_dim=config.hidden_size,
+        output_dim=config.projection_dim,
+        patch_size=config.patch_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
+        layer_norm_eps=config.layer_norm_eps,
     )
-    x = torch.randn(1, 3, image_size, image_size)
+    x = torch.randn(1, 3, config.image_size, config.image_size)
 
     converter = ModelConverter(source_model=source, target_model=target, verbose=True)
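Casting the Hugging Face config to a `NamedTuple` replaces the eleven per-attribute `# type: ignore` comments above with a single typed access point. A minimal sketch of the trick, with hypothetical names (`TinyConfig` and `UntypedConfig` are stand-ins, not part of the diff):

```python
from typing import NamedTuple, cast


class TinyConfig(NamedTuple):
    hidden_size: int
    num_hidden_layers: int


class UntypedConfig:  # stand-in for a transformers config object
    hidden_size = 768
    num_hidden_layers = 12


# `cast` performs no runtime conversion; it only tells the type checker
# which attributes (and types) to expect on `config`.
config = cast(TinyConfig, UntypedConfig())
print(config.hidden_size, config.num_hidden_layers)  # 768 12
```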


@@ -1,6 +1,6 @@
 import argparse
 from pathlib import Path
-from typing import cast
+from typing import NamedTuple, cast
 
 from torch import nn
 from transformers import CLIPTextModel, CLIPTextModelWithProjection  # type: ignore
@@ -21,6 +21,18 @@ class Args(argparse.Namespace):
     verbose: bool
 
 
+class CLIPTextEncoderConfig(NamedTuple):
+    architectures: list[str]
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    hidden_act: str
+    layer_norm_eps: float
+    projection_dim: int
+
+
 def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
     cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
@@ -30,19 +42,17 @@ def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    projection_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feed_forward_dim: int = source.config.intermediate_size  # type: ignore
-    use_quick_gelu: bool = source.config.hidden_act == "quick_gelu"  # type: ignore
+    config = cast(CLIPTextEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
+    architecture: str = config.architectures[0]
+    embedding_dim: int = config.hidden_size
+    projection_dim: int = config.projection_dim
+    use_quick_gelu = config.hidden_act == "quick_gelu"
     assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
     target = CLIPTextEncoder(
-        embedding_dim=embedding_dim,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feed_forward_dim,
+        embedding_dim=config.hidden_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
         use_quick_gelu=use_quick_gelu,
     )
     if architecture == "CLIPTextModelWithProjection":


@@ -25,7 +25,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
             zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p
 
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
@@ -35,7 +35,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
         p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p
 
     @property
     def num_embeddings(self) -> int:
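As the `F.embedding` comment above notes, looking up tokens in a concatenation of a frozen old weight and a trainable new weight routes gradients to the new rows only. A self-contained sketch of that mechanism (toy shapes, not the adapter's real code):

```python
import torch
from torch import cat, nn
from torch.nn import functional as F

old_weight = nn.Parameter(torch.randn(10, 4), requires_grad=False)  # frozen vocabulary
new_weight = nn.Parameter(torch.randn(2, 4))  # trainable extension rows

ids = torch.tensor([0, 10, 11])  # token 0 is old; tokens 10 and 11 are new
out = F.embedding(ids, cat([old_weight, new_weight]))
out.sum().backward()

assert old_weight.grad is None  # the frozen rows received no gradient
assert new_weight.grad is not None  # only the extension gets updated
```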


@@ -1,5 +1,4 @@
 from math import sqrt
-from typing import cast
 
 import torch
 from torch import Tensor
@@ -123,8 +122,7 @@ class LayerScale(fl.WeightedModule):
             ),
         )
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
+        self.register_parameter(name="weight", param=p)
 
     def forward(self, x: Tensor) -> Tensor:
         return x * self.weight


@@ -56,7 +56,7 @@ class Euler(Solver):
     def _generate_sigmas(self) -> Tensor:
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
-        sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
+        sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
         sigmas = torch.cat([sigmas, tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)
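The dropped `.numpy()` calls were redundant: NumPy coerces its inputs with `np.asarray`, and CPU tensors convert through `__array__`; tensors must be moved off the GPU either way. A quick sketch of the equivalence (made-up values):

```python
import numpy as np
import torch

timesteps = torch.linspace(0, 9, 5)
sigmas = torch.rand(10)

a = np.interp(timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy())
b = np.interp(timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu())
assert np.allclose(a, b)  # same result with or without the explicit conversion
```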


@@ -182,8 +182,7 @@ class MaskEncoder(fl.Chain):
             ),
         )
         p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
+        self.register_parameter("no_mask_embedding", p)
 
     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1


@@ -8,7 +8,10 @@ from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, BeforeValidator, ConfigDict
 from torch import Tensor
-from torch.optim import SGD, Adam, AdamW, Optimizer
+from torch.optim.adam import Adam
+from torch.optim.adamw import AdamW
+from torch.optim.optimizer import Optimizer
+from torch.optim.sgd import SGD
 
 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
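Importing `Adam`, `AdamW`, `SGD`, and `Optimizer` from their defining submodules gives pyright a direct path to the typed definitions; presumably the flat `torch.optim` re-exports no longer resolve cleanly under the torch 2.4 stubs. A sketch of the import style in use, with a hypothetical `build_optimizer` helper that is not part of this config module:

```python
from typing import Iterable

from torch import Tensor
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer
from torch.optim.sgd import SGD


def build_optimizer(name: str, params: Iterable[Tensor], lr: float) -> Optimizer:
    # hypothetical dispatch: map a config string to a typed optimizer class
    optimizers: dict[str, type[Optimizer]] = {"adam": Adam, "adamw": AdamW, "sgd": SGD}
    return optimizers[name](params, lr=lr)
```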


@@ -7,7 +7,6 @@ import torch
 from loguru import logger
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.autograd import backward
-from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
     CosineAnnealingWarmRestarts,
@@ -21,6 +20,7 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
+from torch.optim.optimizer import Optimizer
 
 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -226,7 +226,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
             case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
-                    optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
+                    optimizer=self.optimizer,
+                    T_max=scheduler_step_size,
+                    eta_min=config.eta_min,  # pyright: ignore[reportArgumentType]
                 )
             case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(

Binary file not shown (image: 233 KiB → 234 KiB)

Binary file not shown (image: 315 KiB → 315 KiB)


@@ -12,7 +12,7 @@ from segment_anything_hq import (  # type: ignore
 )
 from segment_anything_hq.modeling.sam import Sam  # type: ignore
 from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
-from torch import optim
+from torch.optim.sgd import SGD
 
 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.segment_anything.hq_sam import (
@@ -125,13 +125,13 @@ def test_mask_decoder_tokens_extender() -> None:
         case _:
             raise ValueError
 
-    optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
+    optimizer = SGD(mask_decoder_tokens.parameters(), lr=10)
     optimizer.zero_grad()
 
     ones = torch.ones_like(tokens_before)
     loss = torch.nn.functional.mse_loss(tokens_before, ones)
-    loss.backward()  # type: ignore
-    optimizer.step()
+    loss.backward()  # pyright: ignore[reportUnknownMemberType]
+    optimizer.step()  # pyright: ignore[reportUnknownMemberType]
 
     tokens_after = mask_decoder_tokens()
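The blanket `# type: ignore` on `loss.backward()` also becomes a rule-scoped pyright suppression; the scoped form silences only the named diagnostic and leaves any other error on the line visible. A standalone sketch of the same training-step shape (not the HQ-SAM test itself):

```python
import torch
from torch import nn
from torch.optim.sgd import SGD

layer = nn.Linear(4, 4)
optimizer = SGD(layer.parameters(), lr=10)
optimizer.zero_grad()

loss = nn.functional.mse_loss(layer(torch.randn(2, 4)), torch.zeros(2, 4))
loss.backward()  # pyright: ignore[reportUnknownMemberType]
optimizer.step()  # pyright: ignore[reportUnknownMemberType]
```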


@@ -379,7 +379,7 @@ def test_predictor_dense_mask(
     mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW
 
     assert np.allclose(
-        mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
+        mask_input.cpu(), facebook_mask_input, atol=1e-1
     )  # Lower doesn't pass, but it's close enough for logits
 
     refiners_prompt = SAMPrompt(


@@ -2,12 +2,11 @@ import random
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import cast
 
 import pytest
 import torch
 from torch import Tensor, nn
-from torch.optim import SGD
+from torch.optim.sgd import SGD
 
 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
@@ -188,8 +187,7 @@ def test_count_learnable_parameters_with_params() -> None:
         nn.Parameter(torch.randn(5), requires_grad=False),
         nn.Parameter(torch.randn(3, 3), requires_grad=True),
     ]
-    # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-    assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
+    assert count_learnable_parameters(params) == 13
 
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: