fix typing issues introduced by torch 2.4; typing is not guaranteed for torch < 2.4

limiteinductive 2024-08-02 09:56:37 +00:00 committed by Benjamin Trom
parent 10fba3467a
commit 1de567590b
14 changed files with 115 additions and 92 deletions
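
The recurring change across these files: with torch 2.4 the type stubs around `nn.Parameter` are complete enough that the `cast(...)` workaround for https://github.com/pytorch/pytorch/issues/118736 (needed on torch 2.2/2.3) can be dropped, and optimizers are imported from their defining submodules. A minimal sketch of the parameter-registration pattern, loosely adapted from the `LayerScale` and `MaskEncoder` hunks below (illustrative only, not the exact refiners code):

```python
import torch
from torch import nn


class LayerScale(nn.Module):
    """Simplified stand-in for the refiners LayerScale module."""

    def __init__(self, embedding_dim: int, init_value: float = 0.1) -> None:
        super().__init__()
        p = nn.Parameter(torch.full((embedding_dim,), init_value))
        # torch < 2.4: pyright wanted `param=cast(nn.Parameter, p)` here,
        # see https://github.com/pytorch/pytorch/issues/118736
        # torch >= 2.4: the parameter can be registered directly
        self.register_parameter(name="weight", param=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight


scale = LayerScale(embedding_dim=8)
print(scale(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```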

View file

@@ -10,7 +10,9 @@
 # universal: false
 -e file:.
-aiohttp==3.9.5
+aiohappyeyeballs==2.3.4
+# via aiohttp
+aiohttp==3.10.0
 # via datasets
 # via fsspec
 aiosignal==1.3.1
@@ -27,20 +29,22 @@ attrs==23.2.0
 # via referencing
 babel==2.15.0
 # via mkdocs-material
-bitsandbytes==0.43.1
+backports-strenum==1.3.1
+# via griffe
+bitsandbytes==0.43.3
 # via refiners
 black==24.4.2
 # via refiners
-boto3==1.34.131
+boto3==1.34.152
 # via neptune
-botocore==1.34.131
+botocore==1.34.152
 # via boto3
 # via s3transfer
 bravado==11.0.3
 # via neptune
 bravado-core==6.1.1
 # via bravado
-certifi==2024.6.2
+certifi==2024.7.4
 # via requests
 # via sentry-sdk
 charset-normalizer==3.3.2
@@ -56,7 +60,7 @@ colorama==0.4.6
 # via mkdocs-material
 datasets==2.20.0
 # via refiners
-diffusers==0.29.1
+diffusers==0.29.2
 # via refiners
 dill==0.3.8
 # via datasets
@@ -89,9 +93,9 @@ gitpython==3.1.43
 # via neptune
 # via refiners
 # via wandb
-griffe==0.47.0
+griffe==0.48.0
 # via mkdocstrings-python
-huggingface-hub==0.23.4
+huggingface-hub==0.24.5
 # via datasets
 # via diffusers
 # via timm
@@ -101,11 +105,13 @@ idna==3.7
 # via jsonschema
 # via requests
 # via yarl
-importlib-metadata==7.2.1
+importlib-metadata==8.2.0
 # via diffusers
+importlib-resources==6.4.0
+# via swagger-spec-validator
 isoduration==20.11.0
 # via jsonschema
-jaxtyping==0.2.30
+jaxtyping==0.2.33
 # via refiners
 jinja2==3.1.4
 # via mkdocs
@@ -119,7 +125,7 @@ jsonpointer==3.0.0
 # via jsonschema
 jsonref==1.1.0
 # via bravado-core
-jsonschema==4.22.0
+jsonschema==4.23.0
 # via bravado-core
 # via swagger-spec-validator
 jsonschema-specifications==2023.12.1
@@ -151,14 +157,14 @@ mkdocs-get-deps==0.2.0
 # via mkdocs
 mkdocs-literate-nav==0.6.1
 # via refiners
-mkdocs-material==9.5.27
+mkdocs-material==9.5.30
 # via refiners
 mkdocs-material-extensions==1.3.1
 # via mkdocs-material
-mkdocstrings==0.25.1
+mkdocstrings==0.25.2
 # via mkdocstrings-python
 # via refiners
-mkdocstrings-python==1.10.5
+mkdocstrings-python==1.10.7
 # via mkdocstrings
 monotonic==1.6
 # via bravado
@@ -178,7 +184,7 @@ neptune==1.10.4
 # via refiners
 networkx==3.3
 # via torch
-numpy==2.0.0
+numpy==2.0.1
 # via bitsandbytes
 # via datasets
 # via diffusers
@@ -197,7 +203,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
 # via torch
 nvidia-cuda-runtime-cu12==12.1.105
 # via torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
 # via torch
 nvidia-cufft-cu12==11.0.2.54
 # via torch
@@ -210,7 +216,7 @@ nvidia-cusparse-cu12==12.1.0.106
 # via torch
 nvidia-nccl-cu12==2.20.5
 # via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.6.20
 # via nvidia-cusolver-cu12
 # via nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
@@ -248,24 +254,24 @@ platformdirs==4.2.2
 # via wandb
 prodigyopt==1.0
 # via refiners
-protobuf==5.27.1
+protobuf==5.27.3
 # via wandb
 psutil==6.0.0
 # via neptune
 # via wandb
-pyarrow==16.1.0
+pyarrow==17.0.0
 # via datasets
 pyarrow-hotfix==0.6
 # via datasets
-pydantic==2.7.4
+pydantic==2.8.2
 # via refiners
-pydantic-core==2.18.4
+pydantic-core==2.20.1
 # via pydantic
 pygments==2.18.0
 # via mkdocs-material
-pyjwt==2.8.0
+pyjwt==2.9.0
 # via neptune
-pymdown-extensions==10.8.1
+pymdown-extensions==10.9
 # via mkdocs-material
 # via mkdocstrings
 python-dateutil==2.9.0.post0
@@ -296,7 +302,7 @@ pyyaml-env-tag==0.1
 referencing==0.35.1
 # via jsonschema
 # via jsonschema-specifications
-regex==2024.5.15
+regex==2024.7.24
 # via diffusers
 # via mkdocs-material
 # via transformers
@@ -318,10 +324,10 @@ rfc3339-validator==0.1.4
 # via jsonschema
 rfc3986-validator==0.1.1
 # via jsonschema
-rpds-py==0.18.1
+rpds-py==0.19.1
 # via jsonschema
 # via referencing
-s3transfer==0.10.1
+s3transfer==0.10.2
 # via boto3
 safetensors==0.4.3
 # via diffusers
@@ -330,13 +336,13 @@ safetensors==0.4.3
 # via transformers
 segment-anything-hq==0.3
 # via refiners
-segment-anything-py==1.0
+segment-anything-py==1.0.1
 # via refiners
-sentry-sdk==2.6.0
+sentry-sdk==2.12.0
 # via wandb
 setproctitle==1.3.3
 # via wandb
-setuptools==70.1.0
+setuptools==72.1.0
 # via wandb
 simplejson==3.19.2
 # via bravado
@@ -350,26 +356,26 @@ six==1.16.0
 # via rfc3339-validator
 smmap==5.0.1
 # via gitdb
-swagger-spec-validator==3.0.3
+swagger-spec-validator==3.0.4
 # via bravado-core
 # via neptune
-sympy==1.12.1
+sympy==1.13.1
 # via torch
-timm==1.0.7
+timm==1.0.8
 # via refiners
 tokenizers==0.19.1
 # via transformers
 tomli==2.0.1
 # via black
 # via refiners
-torch==2.3.1
+torch==2.4.0
 # via bitsandbytes
 # via refiners
 # via segment-anything-hq
 # via segment-anything-py
 # via timm
 # via torchvision
-torchvision==0.18.1
+torchvision==0.19.0
 # via piq
 # via refiners
 # via segment-anything-hq
@@ -380,9 +386,9 @@ tqdm==4.66.4
 # via huggingface-hub
 # via refiners
 # via transformers
-transformers==4.41.2
+transformers==4.43.3
 # via refiners
-triton==2.3.1
+triton==3.0.0
 # via torch
 typeguard==2.13.3
 # via jaxtyping
@@ -406,7 +412,7 @@ urllib3==2.2.2
 # via neptune
 # via requests
 # via sentry-sdk
-wandb==0.17.2
+wandb==0.17.5
 # via refiners
 watchdog==4.0.1
 # via mkdocs

View file

@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+from typing import NamedTuple, cast

 import torch
 from torch import nn
@@ -20,6 +21,20 @@ class Args(argparse.Namespace):
     threshold: float


+class CLIPImageEncoderConfig(NamedTuple):
+    architectures: list[str]
+    num_channels: int
+    hidden_size: int
+    hidden_act: str
+    image_size: int
+    projection_dim: int
+    patch_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    intermediate_size: int
+    layer_norm_eps: float
+
+
 def setup_converter(args: Args) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     source: nn.Module = CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
@@ -28,34 +43,26 @@ def setup_converter(args: Args) -> ModelConverter:
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    num_channels: int = source.config.num_channels  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    image_size: int = source.config.image_size  # type: ignore
-    patch_size: int = source.config.patch_size  # type: ignore
-    output_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feedforward_dim: int = source.config.intermediate_size  # type: ignore
-    activation: str = source.config.hidden_act  # type: ignore
-    layer_norm_eps: float = source.config.layer_norm_eps  # type: ignore
-    assert architecture == "CLIPVisionModelWithProjection", f"Unsupported architecture: {architecture}"
-    assert num_channels == 3, f"Expected 3 input channels, got {num_channels}"
-    assert activation == "gelu", f"Unsupported activation: {activation}"
+    config = cast(CLIPImageEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
+    assert (
+        config.architectures[0] == "CLIPVisionModelWithProjection"
+    ), f"Unsupported architecture: {config.architectures[0]}"
+    assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
+    assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"
     target = CLIPImageEncoder(
-        image_size=image_size,
-        embedding_dim=embedding_dim,
-        output_dim=output_dim,
-        patch_size=patch_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feedforward_dim,
-        layer_norm_eps=layer_norm_eps,
+        image_size=config.image_size,
+        embedding_dim=config.hidden_size,
+        output_dim=config.projection_dim,
+        patch_size=config.patch_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
+        layer_norm_eps=config.layer_norm_eps,
     )
-    x = torch.randn(1, 3, image_size, image_size)
+    x = torch.randn(1, 3, config.image_size, config.image_size)
     converter = ModelConverter(source_model=source, target_model=target, verbose=True)
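
The `cast` to a `NamedTuple` above is purely a static-typing device: at runtime `source.config` stays the Hugging Face config object, and the cast only tells the checker which attribute names and types to expect. A standalone sketch of the same pattern with a hypothetical untyped config (the names here are illustrative, not from refiners or transformers):

```python
from typing import NamedTuple, cast


class VisionConfig(NamedTuple):
    hidden_size: int
    image_size: int


class UntypedConfig:  # stand-in for a dynamically typed third-party config
    hidden_size = 768
    image_size = 224


config = cast(VisionConfig, UntypedConfig())  # no runtime conversion happens
print(config.hidden_size, config.image_size)  # duck-typed attribute access: 768 224
```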

View file

@@ -1,6 +1,6 @@
 import argparse
 from pathlib import Path
-from typing import cast
+from typing import NamedTuple, cast

 from torch import nn
 from transformers import CLIPTextModel, CLIPTextModelWithProjection  # type: ignore
@@ -21,6 +21,18 @@ class Args(argparse.Namespace):
     verbose: bool


+class CLIPTextEncoderConfig(NamedTuple):
+    architectures: list[str]
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    hidden_act: str
+    layer_norm_eps: float
+    projection_dim: int
+
+
 def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
@@ -30,19 +42,17 @@ def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    projection_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feed_forward_dim: int = source.config.intermediate_size  # type: ignore
-    use_quick_gelu: bool = source.config.hidden_act == "quick_gelu"  # type: ignore
+    config = cast(CLIPTextEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
+    architecture: str = config.architectures[0]
+    embedding_dim: int = config.hidden_size
+    projection_dim: int = config.projection_dim
+    use_quick_gelu = config.hidden_act == "quick_gelu"
     assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
     target = CLIPTextEncoder(
-        embedding_dim=embedding_dim,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feed_forward_dim,
+        embedding_dim=config.hidden_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
         use_quick_gelu=use_quick_gelu,
     )
     if architecture == "CLIPTextModelWithProjection":

View file

@@ -25,7 +25,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
             zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p

     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
@@ -35,7 +35,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
         p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p

     @property
     def num_embeddings(self) -> int:

View file

@@ -1,5 +1,4 @@
 from math import sqrt
-from typing import cast

 import torch
 from torch import Tensor
@@ -123,8 +122,7 @@ class LayerScale(fl.WeightedModule):
             ),
         )
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
+        self.register_parameter(name="weight", param=p)

     def forward(self, x: Tensor) -> Tensor:
         return x * self.weight

View file

@@ -56,7 +56,7 @@ class Euler(Solver):
     def _generate_sigmas(self) -> Tensor:
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
-        sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
+        sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
         sigmas = torch.cat([sigmas, tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)
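
Dropping the explicit `.numpy()` calls is possible because a CPU tensor is already a valid array-like for NumPy (it is converted through `__array__`), so the extra hop adds nothing; presumably this also plays nicer with the torch 2.4 / numpy 2.0 stubs. A tiny self-contained check of the same idea (the values are made up, not the solver's):

```python
import numpy as np
import torch

timesteps = torch.tensor([0.0, 2.5, 7.5])
sigmas = torch.linspace(1.0, 0.1, 10)
# CPU tensors can be passed straight to numpy functions as array-likes.
interpolated = np.interp(timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu())
print(torch.tensor(interpolated))  # tensor of 3 interpolated sigma values
```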

View file

@@ -182,8 +182,7 @@ class MaskEncoder(fl.Chain):
             ),
         )
         p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
+        self.register_parameter("no_mask_embedding", p)

     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1

View file

@@ -8,7 +8,10 @@ from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, BeforeValidator, ConfigDict
 from torch import Tensor
-from torch.optim import SGD, Adam, AdamW, Optimizer
+from torch.optim.adam import Adam
+from torch.optim.adamw import AdamW
+from torch.optim.optimizer import Optimizer
+from torch.optim.sgd import SGD

 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
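
The move from `from torch.optim import SGD, Adam, AdamW, Optimizer` to per-submodule imports keeps the same runtime objects; the submodule paths are the canonical defining locations, which is presumably what keeps strict pyright happy here. A small usage sketch with a throwaway model (the model and loop are placeholders, not refiners code):

```python
import torch
from torch import nn
from torch.optim.sgd import SGD  # canonical submodule import, as in this commit

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

x, target = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```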

View file

@@ -7,7 +7,6 @@ import torch
 from loguru import logger
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.autograd import backward
-from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
     CosineAnnealingWarmRestarts,
@@ -21,6 +20,7 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
+from torch.optim.optimizer import Optimizer

 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -226,7 +226,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
             case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
-                    optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
+                    optimizer=self.optimizer,
+                    T_max=scheduler_step_size,
+                    eta_min=config.eta_min,  # pyright: ignore[reportArgumentType]
                 )
             case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(

Binary file not shown (image; before: 233 KiB, after: 234 KiB).

Binary file not shown (image; before: 315 KiB, after: 315 KiB).

View file

@@ -12,7 +12,7 @@ from segment_anything_hq import (  # type: ignore
 )
 from segment_anything_hq.modeling.sam import Sam  # type: ignore
 from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
-from torch import optim
+from torch.optim.sgd import SGD

 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.segment_anything.hq_sam import (
@@ -125,13 +125,13 @@ def test_mask_decoder_tokens_extender() -> None:
             case _:
                 raise ValueError

-    optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
+    optimizer = SGD(mask_decoder_tokens.parameters(), lr=10)
     optimizer.zero_grad()

     ones = torch.ones_like(tokens_before)
     loss = torch.nn.functional.mse_loss(tokens_before, ones)
-    loss.backward()  # type: ignore
-    optimizer.step()
+    loss.backward()  # pyright: ignore[reportUnknownMemberType]
+    optimizer.step()  # pyright: ignore[reportUnknownMemberType]

     tokens_after = mask_decoder_tokens()

View file

@@ -379,7 +379,7 @@ def test_predictor_dense_mask(
     mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW
     assert np.allclose(
-        mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
+        mask_input.cpu(), facebook_mask_input, atol=1e-1
     )  # Lower doesn't pass, but it's close enough for logits

     refiners_prompt = SAMPrompt(

View file

@@ -2,12 +2,11 @@ import random
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import cast

 import pytest
 import torch
 from torch import Tensor, nn
-from torch.optim import SGD
+from torch.optim.sgd import SGD

 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
@@ -188,8 +187,7 @@ def test_count_learnable_parameters_with_params() -> None:
         nn.Parameter(torch.randn(5), requires_grad=False),
         nn.Parameter(torch.randn(3, 3), requires_grad=True),
     ]
-    # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-    assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
+    assert count_learnable_parameters(params) == 13


 def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: