mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix typing issues coming from torch 2.4 version ; typing is not guaranteed for torch < 2.4
This commit is contained in:
parent
10fba3467a
commit
1de567590b
|
@ -10,7 +10,9 @@
|
|||
# universal: false
|
||||
|
||||
-e file:.
|
||||
aiohttp==3.9.5
|
||||
aiohappyeyeballs==2.3.4
|
||||
# via aiohttp
|
||||
aiohttp==3.10.0
|
||||
# via datasets
|
||||
# via fsspec
|
||||
aiosignal==1.3.1
|
||||
|
@ -27,20 +29,22 @@ attrs==23.2.0
|
|||
# via referencing
|
||||
babel==2.15.0
|
||||
# via mkdocs-material
|
||||
bitsandbytes==0.43.1
|
||||
backports-strenum==1.3.1
|
||||
# via griffe
|
||||
bitsandbytes==0.43.3
|
||||
# via refiners
|
||||
black==24.4.2
|
||||
# via refiners
|
||||
boto3==1.34.131
|
||||
boto3==1.34.152
|
||||
# via neptune
|
||||
botocore==1.34.131
|
||||
botocore==1.34.152
|
||||
# via boto3
|
||||
# via s3transfer
|
||||
bravado==11.0.3
|
||||
# via neptune
|
||||
bravado-core==6.1.1
|
||||
# via bravado
|
||||
certifi==2024.6.2
|
||||
certifi==2024.7.4
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
charset-normalizer==3.3.2
|
||||
|
@ -56,7 +60,7 @@ colorama==0.4.6
|
|||
# via mkdocs-material
|
||||
datasets==2.20.0
|
||||
# via refiners
|
||||
diffusers==0.29.1
|
||||
diffusers==0.29.2
|
||||
# via refiners
|
||||
dill==0.3.8
|
||||
# via datasets
|
||||
|
@ -89,9 +93,9 @@ gitpython==3.1.43
|
|||
# via neptune
|
||||
# via refiners
|
||||
# via wandb
|
||||
griffe==0.47.0
|
||||
griffe==0.48.0
|
||||
# via mkdocstrings-python
|
||||
huggingface-hub==0.23.4
|
||||
huggingface-hub==0.24.5
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via timm
|
||||
|
@ -101,11 +105,13 @@ idna==3.7
|
|||
# via jsonschema
|
||||
# via requests
|
||||
# via yarl
|
||||
importlib-metadata==7.2.1
|
||||
importlib-metadata==8.2.0
|
||||
# via diffusers
|
||||
importlib-resources==6.4.0
|
||||
# via swagger-spec-validator
|
||||
isoduration==20.11.0
|
||||
# via jsonschema
|
||||
jaxtyping==0.2.30
|
||||
jaxtyping==0.2.33
|
||||
# via refiners
|
||||
jinja2==3.1.4
|
||||
# via mkdocs
|
||||
|
@ -119,7 +125,7 @@ jsonpointer==3.0.0
|
|||
# via jsonschema
|
||||
jsonref==1.1.0
|
||||
# via bravado-core
|
||||
jsonschema==4.22.0
|
||||
jsonschema==4.23.0
|
||||
# via bravado-core
|
||||
# via swagger-spec-validator
|
||||
jsonschema-specifications==2023.12.1
|
||||
|
@ -151,14 +157,14 @@ mkdocs-get-deps==0.2.0
|
|||
# via mkdocs
|
||||
mkdocs-literate-nav==0.6.1
|
||||
# via refiners
|
||||
mkdocs-material==9.5.27
|
||||
mkdocs-material==9.5.30
|
||||
# via refiners
|
||||
mkdocs-material-extensions==1.3.1
|
||||
# via mkdocs-material
|
||||
mkdocstrings==0.25.1
|
||||
mkdocstrings==0.25.2
|
||||
# via mkdocstrings-python
|
||||
# via refiners
|
||||
mkdocstrings-python==1.10.5
|
||||
mkdocstrings-python==1.10.7
|
||||
# via mkdocstrings
|
||||
monotonic==1.6
|
||||
# via bravado
|
||||
|
@ -178,7 +184,7 @@ neptune==1.10.4
|
|||
# via refiners
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numpy==2.0.0
|
||||
numpy==2.0.1
|
||||
# via bitsandbytes
|
||||
# via datasets
|
||||
# via diffusers
|
||||
|
@ -197,7 +203,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
|
|||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
|
@ -210,7 +216,7 @@ nvidia-cusparse-cu12==12.1.0.106
|
|||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.5.40
|
||||
nvidia-nvjitlink-cu12==12.6.20
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
|
@ -248,24 +254,24 @@ platformdirs==4.2.2
|
|||
# via wandb
|
||||
prodigyopt==1.0
|
||||
# via refiners
|
||||
protobuf==5.27.1
|
||||
protobuf==5.27.3
|
||||
# via wandb
|
||||
psutil==6.0.0
|
||||
# via neptune
|
||||
# via wandb
|
||||
pyarrow==16.1.0
|
||||
pyarrow==17.0.0
|
||||
# via datasets
|
||||
pyarrow-hotfix==0.6
|
||||
# via datasets
|
||||
pydantic==2.7.4
|
||||
pydantic==2.8.2
|
||||
# via refiners
|
||||
pydantic-core==2.18.4
|
||||
pydantic-core==2.20.1
|
||||
# via pydantic
|
||||
pygments==2.18.0
|
||||
# via mkdocs-material
|
||||
pyjwt==2.8.0
|
||||
pyjwt==2.9.0
|
||||
# via neptune
|
||||
pymdown-extensions==10.8.1
|
||||
pymdown-extensions==10.9
|
||||
# via mkdocs-material
|
||||
# via mkdocstrings
|
||||
python-dateutil==2.9.0.post0
|
||||
|
@ -296,7 +302,7 @@ pyyaml-env-tag==0.1
|
|||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
regex==2024.5.15
|
||||
regex==2024.7.24
|
||||
# via diffusers
|
||||
# via mkdocs-material
|
||||
# via transformers
|
||||
|
@ -318,10 +324,10 @@ rfc3339-validator==0.1.4
|
|||
# via jsonschema
|
||||
rfc3986-validator==0.1.1
|
||||
# via jsonschema
|
||||
rpds-py==0.18.1
|
||||
rpds-py==0.19.1
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
s3transfer==0.10.1
|
||||
s3transfer==0.10.2
|
||||
# via boto3
|
||||
safetensors==0.4.3
|
||||
# via diffusers
|
||||
|
@ -330,13 +336,13 @@ safetensors==0.4.3
|
|||
# via transformers
|
||||
segment-anything-hq==0.3
|
||||
# via refiners
|
||||
segment-anything-py==1.0
|
||||
segment-anything-py==1.0.1
|
||||
# via refiners
|
||||
sentry-sdk==2.6.0
|
||||
sentry-sdk==2.12.0
|
||||
# via wandb
|
||||
setproctitle==1.3.3
|
||||
# via wandb
|
||||
setuptools==70.1.0
|
||||
setuptools==72.1.0
|
||||
# via wandb
|
||||
simplejson==3.19.2
|
||||
# via bravado
|
||||
|
@ -350,26 +356,26 @@ six==1.16.0
|
|||
# via rfc3339-validator
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
swagger-spec-validator==3.0.3
|
||||
swagger-spec-validator==3.0.4
|
||||
# via bravado-core
|
||||
# via neptune
|
||||
sympy==1.12.1
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
timm==1.0.7
|
||||
timm==1.0.8
|
||||
# via refiners
|
||||
tokenizers==0.19.1
|
||||
# via transformers
|
||||
tomli==2.0.1
|
||||
# via black
|
||||
# via refiners
|
||||
torch==2.3.1
|
||||
torch==2.4.0
|
||||
# via bitsandbytes
|
||||
# via refiners
|
||||
# via segment-anything-hq
|
||||
# via segment-anything-py
|
||||
# via timm
|
||||
# via torchvision
|
||||
torchvision==0.18.1
|
||||
torchvision==0.19.0
|
||||
# via piq
|
||||
# via refiners
|
||||
# via segment-anything-hq
|
||||
|
@ -380,9 +386,9 @@ tqdm==4.66.4
|
|||
# via huggingface-hub
|
||||
# via refiners
|
||||
# via transformers
|
||||
transformers==4.41.2
|
||||
transformers==4.43.3
|
||||
# via refiners
|
||||
triton==2.3.1
|
||||
triton==3.0.0
|
||||
# via torch
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
|
@ -406,7 +412,7 @@ urllib3==2.2.2
|
|||
# via neptune
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
wandb==0.17.2
|
||||
wandb==0.17.5
|
||||
# via refiners
|
||||
watchdog==4.0.1
|
||||
# via mkdocs
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -20,6 +21,20 @@ class Args(argparse.Namespace):
|
|||
threshold: float
|
||||
|
||||
|
||||
class CLIPImageEncoderConfig(NamedTuple):
|
||||
architectures: list[str]
|
||||
num_channels: int
|
||||
hidden_size: int
|
||||
hidden_act: str
|
||||
image_size: int
|
||||
projection_dim: int
|
||||
patch_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
intermediate_size: int
|
||||
layer_norm_eps: float
|
||||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
||||
|
@ -28,34 +43,26 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
architecture: str = source.config.architectures[0] # type: ignore
|
||||
num_channels: int = source.config.num_channels # type: ignore
|
||||
embedding_dim: int = source.config.hidden_size # type: ignore
|
||||
image_size: int = source.config.image_size # type: ignore
|
||||
patch_size: int = source.config.patch_size # type: ignore
|
||||
output_dim: int = source.config.projection_dim # type: ignore
|
||||
num_layers: int = source.config.num_hidden_layers # type: ignore
|
||||
num_attention_heads: int = source.config.num_attention_heads # type: ignore
|
||||
feedforward_dim: int = source.config.intermediate_size # type: ignore
|
||||
activation: str = source.config.hidden_act # type: ignore
|
||||
layer_norm_eps: float = source.config.layer_norm_eps # type: ignore
|
||||
config = cast(CLIPImageEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType]
|
||||
|
||||
assert architecture == "CLIPVisionModelWithProjection", f"Unsupported architecture: {architecture}"
|
||||
assert num_channels == 3, f"Expected 3 input channels, got {num_channels}"
|
||||
assert activation == "gelu", f"Unsupported activation: {activation}"
|
||||
assert (
|
||||
config.architectures[0] == "CLIPVisionModelWithProjection"
|
||||
), f"Unsupported architecture: {config.architectures[0]}"
|
||||
assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
|
||||
assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"
|
||||
|
||||
target = CLIPImageEncoder(
|
||||
image_size=image_size,
|
||||
embedding_dim=embedding_dim,
|
||||
output_dim=output_dim,
|
||||
patch_size=patch_size,
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
feedforward_dim=feedforward_dim,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
image_size=config.image_size,
|
||||
embedding_dim=config.hidden_size,
|
||||
output_dim=config.projection_dim,
|
||||
patch_size=config.patch_size,
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
feedforward_dim=config.intermediate_size,
|
||||
layer_norm_eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
x = torch.randn(1, 3, image_size, image_size)
|
||||
x = torch.randn(1, 3, config.image_size, config.image_size)
|
||||
|
||||
converter = ModelConverter(source_model=source, target_model=target, verbose=True)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from typing import NamedTuple, cast
|
||||
|
||||
from torch import nn
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection # type: ignore
|
||||
|
@ -21,6 +21,18 @@ class Args(argparse.Namespace):
|
|||
verbose: bool
|
||||
|
||||
|
||||
class CLIPTextEncoderConfig(NamedTuple):
|
||||
architectures: list[str]
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
hidden_act: str
|
||||
layer_norm_eps: float
|
||||
projection_dim: int
|
||||
|
||||
|
||||
def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
|
||||
|
@ -30,19 +42,17 @@ def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter
|
|||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
architecture: str = source.config.architectures[0] # type: ignore
|
||||
embedding_dim: int = source.config.hidden_size # type: ignore
|
||||
projection_dim: int = source.config.projection_dim # type: ignore
|
||||
num_layers: int = source.config.num_hidden_layers # type: ignore
|
||||
num_attention_heads: int = source.config.num_attention_heads # type: ignore
|
||||
feed_forward_dim: int = source.config.intermediate_size # type: ignore
|
||||
use_quick_gelu: bool = source.config.hidden_act == "quick_gelu" # type: ignore
|
||||
config = cast(CLIPTextEncoderConfig, source.config) # pyright: ignore[reportArgumentType, reportUnknownMemberType]
|
||||
architecture: str = config.architectures[0]
|
||||
embedding_dim: int = config.hidden_size
|
||||
projection_dim: int = config.projection_dim
|
||||
use_quick_gelu = config.hidden_act == "quick_gelu"
|
||||
assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
|
||||
target = CLIPTextEncoder(
|
||||
embedding_dim=embedding_dim,
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
feedforward_dim=feed_forward_dim,
|
||||
embedding_dim=config.hidden_size,
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
feedforward_dim=config.intermediate_size,
|
||||
use_quick_gelu=use_quick_gelu,
|
||||
)
|
||||
if architecture == "CLIPTextModelWithProjection":
|
||||
|
|
|
@ -25,7 +25,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
|||
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
|
||||
) # requires_grad=True by default
|
||||
self.old_weight = cast(Parameter, target.weight)
|
||||
self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||
self.new_weight = p
|
||||
|
||||
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
|
||||
def lookup(self, x: Tensor) -> Tensor:
|
||||
|
@ -35,7 +35,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
|||
def add_embedding(self, embedding: Tensor) -> None:
|
||||
assert embedding.shape == (self.old_weight.shape[1],)
|
||||
p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
|
||||
self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||
self.new_weight = p
|
||||
|
||||
@property
|
||||
def num_embeddings(self) -> int:
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from math import sqrt
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -123,8 +122,7 @@ class LayerScale(fl.WeightedModule):
|
|||
),
|
||||
)
|
||||
|
||||
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||
self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
|
||||
self.register_parameter(name="weight", param=p)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x * self.weight
|
||||
|
|
|
@ -56,7 +56,7 @@ class Euler(Solver):
|
|||
def _generate_sigmas(self) -> Tensor:
|
||||
"""Generate the sigmas used by the solver."""
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
|
||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
|
|
|
@ -182,8 +182,7 @@ class MaskEncoder(fl.Chain):
|
|||
),
|
||||
)
|
||||
p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
|
||||
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||
self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
|
||||
self.register_parameter("no_mask_embedding", p)
|
||||
|
||||
def get_no_mask_dense_embedding(
|
||||
self, image_embedding_size: tuple[int, int], batch_size: int = 1
|
||||
|
|
|
@ -8,7 +8,10 @@ from bitsandbytes.optim import AdamW8bit, Lion8bit # type: ignore
|
|||
from prodigyopt import Prodigy # type: ignore
|
||||
from pydantic import BaseModel, BeforeValidator, ConfigDict
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD, Adam, AdamW, Optimizer
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.adamw import AdamW
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from refiners.training_utils.clock import ClockConfig
|
||||
from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
|
||||
|
|
|
@ -7,7 +7,6 @@ import torch
|
|||
from loguru import logger
|
||||
from torch import Tensor, device as Device, dtype as DType, nn
|
||||
from torch.autograd import backward
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import (
|
||||
CosineAnnealingLR,
|
||||
CosineAnnealingWarmRestarts,
|
||||
|
@ -21,6 +20,7 @@ from torch.optim.lr_scheduler import (
|
|||
ReduceLROnPlateau,
|
||||
StepLR,
|
||||
)
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.training_utils.callback import (
|
||||
|
@ -226,7 +226,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
|
||||
case LRSchedulerType.COSINE_ANNEALING_LR:
|
||||
lr_scheduler = CosineAnnealingLR(
|
||||
optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
|
||||
optimizer=self.optimizer,
|
||||
T_max=scheduler_step_size,
|
||||
eta_min=config.eta_min, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
|
||||
lr_scheduler = cast(
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 233 KiB After Width: | Height: | Size: 234 KiB |
Binary file not shown.
Before Width: | Height: | Size: 315 KiB After Width: | Height: | Size: 315 KiB |
|
@ -12,7 +12,7 @@ from segment_anything_hq import ( # type: ignore
|
|||
)
|
||||
from segment_anything_hq.modeling.sam import Sam # type: ignore
|
||||
from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
|
||||
from torch import optim
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
|
||||
from refiners.foundationals.segment_anything.hq_sam import (
|
||||
|
@ -125,13 +125,13 @@ def test_mask_decoder_tokens_extender() -> None:
|
|||
case _:
|
||||
raise ValueError
|
||||
|
||||
optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
|
||||
optimizer = SGD(mask_decoder_tokens.parameters(), lr=10)
|
||||
optimizer.zero_grad()
|
||||
|
||||
ones = torch.ones_like(tokens_before)
|
||||
loss = torch.nn.functional.mse_loss(tokens_before, ones)
|
||||
loss.backward() # type: ignore
|
||||
optimizer.step()
|
||||
loss.backward() # pyright: ignore[reportUnknownMemberType]
|
||||
optimizer.step() # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
tokens_after = mask_decoder_tokens()
|
||||
|
||||
|
|
|
@ -379,7 +379,7 @@ def test_predictor_dense_mask(
|
|||
mask_input = logits[:, scores.max(dim=0).indices, ...] # shape: 1xHxW
|
||||
|
||||
assert np.allclose(
|
||||
mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
|
||||
mask_input.cpu(), facebook_mask_input, atol=1e-1
|
||||
) # Lower doesn't pass, but it's close enough for logits
|
||||
|
||||
refiners_prompt = SAMPrompt(
|
||||
|
|
|
@ -2,12 +2,11 @@ import random
|
|||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.optim import SGD
|
||||
from torch.optim.sgd import SGD
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import norm
|
||||
|
@ -188,8 +187,7 @@ def test_count_learnable_parameters_with_params() -> None:
|
|||
nn.Parameter(torch.randn(5), requires_grad=False),
|
||||
nn.Parameter(torch.randn(3, 3), requires_grad=True),
|
||||
]
|
||||
# cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
|
||||
assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
|
||||
assert count_learnable_parameters(params) == 13
|
||||
|
||||
|
||||
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
|
||||
|
|
Loading…
Reference in a new issue