mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-10 15:26:16 +00:00)

fix typing issues coming from torch 2.4 version; typing is not guaranteed for torch < 2.4

commit 1de567590b (parent 10fba3467a)
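
What changed, in short: torch 2.4 ships type annotations complete enough that the workarounds used while typing was unreliable can go. Concretely, the diff bumps the lock file to torch 2.4.0, replaces untyped `source.config` attribute access in the conversion scripts with a `cast` to a local `NamedTuple`, drops the `cast(Parameter, ...)` wrappers that a torch 2.2 stub regression had required, removes redundant `.numpy()` round-trips, and imports optimizers from their concrete `torch.optim` submodules. The recurring pattern is the config cast; a minimal sketch with illustrative names (`SomeConfig` and `as_typed` are hypothetical, not from the diff):

from typing import NamedTuple, cast

class SomeConfig(NamedTuple):
    # mirror only the fields the caller actually reads
    hidden_size: int
    num_hidden_layers: int

def as_typed(untyped_config: object) -> SomeConfig:
    # cast() is erased at runtime; it only tells pyright what shape the
    # dynamically built config object is expected to have
    return cast(SomeConfig, untyped_config)
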
@@ -10,7 +10,9 @@
 # universal: false

 -e file:.
-aiohttp==3.9.5
+aiohappyeyeballs==2.3.4
+    # via aiohttp
+aiohttp==3.10.0
     # via datasets
     # via fsspec
 aiosignal==1.3.1
@@ -27,20 +29,22 @@ attrs==23.2.0
     # via referencing
 babel==2.15.0
     # via mkdocs-material
-bitsandbytes==0.43.1
+backports-strenum==1.3.1
+    # via griffe
+bitsandbytes==0.43.3
     # via refiners
 black==24.4.2
     # via refiners
-boto3==1.34.131
+boto3==1.34.152
     # via neptune
-botocore==1.34.131
+botocore==1.34.152
     # via boto3
     # via s3transfer
 bravado==11.0.3
     # via neptune
 bravado-core==6.1.1
     # via bravado
-certifi==2024.6.2
+certifi==2024.7.4
     # via requests
     # via sentry-sdk
 charset-normalizer==3.3.2
@@ -56,7 +60,7 @@ colorama==0.4.6
     # via mkdocs-material
 datasets==2.20.0
     # via refiners
-diffusers==0.29.1
+diffusers==0.29.2
     # via refiners
 dill==0.3.8
     # via datasets
@@ -89,9 +93,9 @@ gitpython==3.1.43
     # via neptune
     # via refiners
     # via wandb
-griffe==0.47.0
+griffe==0.48.0
     # via mkdocstrings-python
-huggingface-hub==0.23.4
+huggingface-hub==0.24.5
     # via datasets
     # via diffusers
     # via timm
@@ -101,11 +105,13 @@ idna==3.7
     # via jsonschema
     # via requests
     # via yarl
-importlib-metadata==7.2.1
+importlib-metadata==8.2.0
     # via diffusers
+importlib-resources==6.4.0
+    # via swagger-spec-validator
 isoduration==20.11.0
     # via jsonschema
-jaxtyping==0.2.30
+jaxtyping==0.2.33
     # via refiners
 jinja2==3.1.4
     # via mkdocs
@@ -119,7 +125,7 @@ jsonpointer==3.0.0
     # via jsonschema
 jsonref==1.1.0
     # via bravado-core
-jsonschema==4.22.0
+jsonschema==4.23.0
     # via bravado-core
     # via swagger-spec-validator
 jsonschema-specifications==2023.12.1
@@ -151,14 +157,14 @@ mkdocs-get-deps==0.2.0
     # via mkdocs
 mkdocs-literate-nav==0.6.1
     # via refiners
-mkdocs-material==9.5.27
+mkdocs-material==9.5.30
     # via refiners
 mkdocs-material-extensions==1.3.1
     # via mkdocs-material
-mkdocstrings==0.25.1
+mkdocstrings==0.25.2
     # via mkdocstrings-python
     # via refiners
-mkdocstrings-python==1.10.5
+mkdocstrings-python==1.10.7
     # via mkdocstrings
 monotonic==1.6
     # via bravado
@@ -178,7 +184,7 @@ neptune==1.10.4
     # via refiners
 networkx==3.3
     # via torch
-numpy==2.0.0
+numpy==2.0.1
     # via bitsandbytes
     # via datasets
     # via diffusers
@@ -197,7 +203,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105
     # via torch
 nvidia-cuda-runtime-cu12==12.1.105
     # via torch
-nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu12==9.1.0.70
     # via torch
 nvidia-cufft-cu12==11.0.2.54
     # via torch
@@ -210,7 +216,7 @@ nvidia-cusparse-cu12==12.1.0.106
     # via torch
 nvidia-nccl-cu12==2.20.5
     # via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.6.20
     # via nvidia-cusolver-cu12
     # via nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
@@ -248,24 +254,24 @@ platformdirs==4.2.2
     # via wandb
 prodigyopt==1.0
     # via refiners
-protobuf==5.27.1
+protobuf==5.27.3
     # via wandb
 psutil==6.0.0
     # via neptune
     # via wandb
-pyarrow==16.1.0
+pyarrow==17.0.0
     # via datasets
 pyarrow-hotfix==0.6
     # via datasets
-pydantic==2.7.4
+pydantic==2.8.2
     # via refiners
-pydantic-core==2.18.4
+pydantic-core==2.20.1
     # via pydantic
 pygments==2.18.0
     # via mkdocs-material
-pyjwt==2.8.0
+pyjwt==2.9.0
     # via neptune
-pymdown-extensions==10.8.1
+pymdown-extensions==10.9
     # via mkdocs-material
     # via mkdocstrings
 python-dateutil==2.9.0.post0
@@ -296,7 +302,7 @@ pyyaml-env-tag==0.1
 referencing==0.35.1
     # via jsonschema
     # via jsonschema-specifications
-regex==2024.5.15
+regex==2024.7.24
     # via diffusers
     # via mkdocs-material
     # via transformers
@@ -318,10 +324,10 @@ rfc3339-validator==0.1.4
     # via jsonschema
 rfc3986-validator==0.1.1
     # via jsonschema
-rpds-py==0.18.1
+rpds-py==0.19.1
     # via jsonschema
     # via referencing
-s3transfer==0.10.1
+s3transfer==0.10.2
     # via boto3
 safetensors==0.4.3
     # via diffusers
@@ -330,13 +336,13 @@ safetensors==0.4.3
     # via transformers
 segment-anything-hq==0.3
     # via refiners
-segment-anything-py==1.0
+segment-anything-py==1.0.1
     # via refiners
-sentry-sdk==2.6.0
+sentry-sdk==2.12.0
     # via wandb
 setproctitle==1.3.3
     # via wandb
-setuptools==70.1.0
+setuptools==72.1.0
     # via wandb
 simplejson==3.19.2
     # via bravado
@@ -350,26 +356,26 @@ six==1.16.0
     # via rfc3339-validator
 smmap==5.0.1
     # via gitdb
-swagger-spec-validator==3.0.3
+swagger-spec-validator==3.0.4
     # via bravado-core
     # via neptune
-sympy==1.12.1
+sympy==1.13.1
     # via torch
-timm==1.0.7
+timm==1.0.8
     # via refiners
 tokenizers==0.19.1
     # via transformers
 tomli==2.0.1
     # via black
     # via refiners
-torch==2.3.1
+torch==2.4.0
     # via bitsandbytes
     # via refiners
     # via segment-anything-hq
     # via segment-anything-py
     # via timm
     # via torchvision
-torchvision==0.18.1
+torchvision==0.19.0
     # via piq
     # via refiners
     # via segment-anything-hq
@@ -380,9 +386,9 @@ tqdm==4.66.4
     # via huggingface-hub
     # via refiners
     # via transformers
-transformers==4.41.2
+transformers==4.43.3
     # via refiners
-triton==2.3.1
+triton==3.0.0
     # via torch
 typeguard==2.13.3
     # via jaxtyping
@@ -406,7 +412,7 @@ urllib3==2.2.2
     # via neptune
     # via requests
     # via sentry-sdk
-wandb==0.17.2
+wandb==0.17.5
     # via refiners
 watchdog==4.0.1
     # via mkdocs
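
The lock-file half of the change is mechanical: torch 2.3.1 -> 2.4.0 with its companions (torchvision 0.19.0, triton 3.0.0, nvidia-cudnn-cu12 9.1.0.70) plus a general refresh of the other pins. Since the typing fix only holds for torch >= 2.4 (per the commit message), a guard like this sketch could make the assumption explicit in downstream code (the guard is hypothetical, not part of the diff):

import torch

# hypothetical guard: the annotations this codebase now relies on are only
# guaranteed from torch 2.4 onwards
assert tuple(int(x) for x in torch.__version__.split(".")[:2]) >= (2, 4)
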
@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+from typing import NamedTuple, cast

 import torch
 from torch import nn
@@ -20,6 +21,20 @@ class Args(argparse.Namespace):
     threshold: float


+class CLIPImageEncoderConfig(NamedTuple):
+    architectures: list[str]
+    num_channels: int
+    hidden_size: int
+    hidden_act: str
+    image_size: int
+    projection_dim: int
+    patch_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    intermediate_size: int
+    layer_norm_eps: float
+
+
 def setup_converter(args: Args) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     source: nn.Module = CLIPVisionModelWithProjection.from_pretrained(  # type: ignore
@@ -28,34 +43,26 @@ def setup_converter(args: Args) -> ModelConverter:
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    num_channels: int = source.config.num_channels  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    image_size: int = source.config.image_size  # type: ignore
-    patch_size: int = source.config.patch_size  # type: ignore
-    output_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feedforward_dim: int = source.config.intermediate_size  # type: ignore
-    activation: str = source.config.hidden_act  # type: ignore
-    layer_norm_eps: float = source.config.layer_norm_eps  # type: ignore
+    config = cast(CLIPImageEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]

-    assert architecture == "CLIPVisionModelWithProjection", f"Unsupported architecture: {architecture}"
-    assert num_channels == 3, f"Expected 3 input channels, got {num_channels}"
-    assert activation == "gelu", f"Unsupported activation: {activation}"
+    assert (
+        config.architectures[0] == "CLIPVisionModelWithProjection"
+    ), f"Unsupported architecture: {config.architectures[0]}"
+    assert config.num_channels == 3, f"Expected 3 input channels, got {config.num_channels}"
+    assert config.hidden_act == "gelu", f"Unsupported activation: {config.hidden_act}"

     target = CLIPImageEncoder(
-        image_size=image_size,
-        embedding_dim=embedding_dim,
-        output_dim=output_dim,
-        patch_size=patch_size,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feedforward_dim,
-        layer_norm_eps=layer_norm_eps,
+        image_size=config.image_size,
+        embedding_dim=config.hidden_size,
+        output_dim=config.projection_dim,
+        patch_size=config.patch_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
+        layer_norm_eps=config.layer_norm_eps,
     )

-    x = torch.randn(1, 3, image_size, image_size)
+    x = torch.randn(1, 3, config.image_size, config.image_size)

     converter = ModelConverter(source_model=source, target_model=target, verbose=True)
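
One detail worth noting in the converter above: cast() does no runtime checking, which is why the asserts on architectures[0], num_channels and hidden_act stay directly after it. A small sketch of the failure mode they guard against (Cfg and obj are made up for illustration):

from typing import NamedTuple, cast

class Cfg(NamedTuple):
    num_channels: int

obj = object()        # not a Cfg at all
cfg = cast(Cfg, obj)  # pyright is satisfied, but nothing is verified
# cfg.num_channels    # would raise AttributeError at runtime
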
@@ -1,6 +1,6 @@
 import argparse
 from pathlib import Path
-from typing import cast
+from typing import NamedTuple, cast

 from torch import nn
 from transformers import CLIPTextModel, CLIPTextModelWithProjection  # type: ignore
@@ -21,6 +21,18 @@ class Args(argparse.Namespace):
     verbose: bool


+class CLIPTextEncoderConfig(NamedTuple):
+    architectures: list[str]
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    hidden_act: str
+    layer_norm_eps: float
+    projection_dim: int
+
+
 def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     cls = CLIPTextModelWithProjection if with_projection else CLIPTextModel
@@ -30,19 +42,17 @@ def setup_converter(args: Args, with_projection: bool = False) -> ModelConverter
         low_cpu_mem_usage=False,
     )
     assert isinstance(source, nn.Module), "Source model is not a nn.Module"
-    architecture: str = source.config.architectures[0]  # type: ignore
-    embedding_dim: int = source.config.hidden_size  # type: ignore
-    projection_dim: int = source.config.projection_dim  # type: ignore
-    num_layers: int = source.config.num_hidden_layers  # type: ignore
-    num_attention_heads: int = source.config.num_attention_heads  # type: ignore
-    feed_forward_dim: int = source.config.intermediate_size  # type: ignore
-    use_quick_gelu: bool = source.config.hidden_act == "quick_gelu"  # type: ignore
+    config = cast(CLIPTextEncoderConfig, source.config)  # pyright: ignore[reportArgumentType, reportUnknownMemberType]
+    architecture: str = config.architectures[0]
+    embedding_dim: int = config.hidden_size
+    projection_dim: int = config.projection_dim
+    use_quick_gelu = config.hidden_act == "quick_gelu"
     assert architecture in ("CLIPTextModel", "CLIPTextModelWithProjection"), f"Unsupported architecture: {architecture}"
     target = CLIPTextEncoder(
-        embedding_dim=embedding_dim,
-        num_layers=num_layers,
-        num_attention_heads=num_attention_heads,
-        feedforward_dim=feed_forward_dim,
+        embedding_dim=config.hidden_size,
+        num_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        feedforward_dim=config.intermediate_size,
         use_quick_gelu=use_quick_gelu,
     )
     if architecture == "CLIPTextModelWithProjection":
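
For context on the use_quick_gelu flag kept by the text converter: some CLIP text checkpoints use OpenAI's "quick" GELU approximation instead of exact GELU, and the two are close but not interchangeable, so the converter has to propagate the checkpoint's choice. The approximation is simply:

import torch

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # OpenAI CLIP's fast GELU approximation
    return x * torch.sigmoid(1.702 * x)
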
@@ -25,7 +25,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
             zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         )  # requires_grad=True by default
         self.old_weight = cast(Parameter, target.weight)
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p

     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
@@ -35,7 +35,7 @@ class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
         p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
-        self.new_weight = cast(Parameter, p)  # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.new_weight = p

     @property
     def num_embeddings(self) -> int:
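
The two casts removed above (and the similar ones further down) all reference the same torch 2.2 typing regression, pytorch/pytorch#118736, under which an nn.Parameter(...) expression was not inferred as Parameter by the type checker. With torch 2.4 the inference works again, so the casts and the now-unused typing.cast imports can go. A minimal sketch of what now type-checks cleanly:

import torch
from torch import cat
from torch.nn import Parameter

new_weight = Parameter(torch.zeros(2, 4))
extra = torch.randn(1, 4)
p: Parameter = Parameter(cat([new_weight, extra]))  # inferred as Parameter under torch 2.4
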
@@ -1,5 +1,4 @@
 from math import sqrt
-from typing import cast

 import torch
 from torch import Tensor
@@ -123,8 +122,7 @@ class LayerScale(fl.WeightedModule):
             ),
         )

-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
+        self.register_parameter(name="weight", param=p)

     def forward(self, x: Tensor) -> Tensor:
         return x * self.weight
@@ -56,7 +56,7 @@ class Euler(Solver):
     def _generate_sigmas(self) -> Tensor:
         """Generate the sigmas used by the solver."""
         sigmas = self.noise_std / self.cumulative_scale_factors
-        sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
+        sigmas = torch.tensor(np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()))
         sigmas = torch.cat([sigmas, tensor([0.0])])
         return sigmas.to(device=self.device, dtype=self.dtype)
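
The Euler change is the same cleanup seen from numpy's side: np.interp accepts any array-like, and a CPU tensor converts implicitly, so the explicit .numpy() round-trips were redundant (and, with torch 2.4's annotations, presumably no longer needed to satisfy the type checker either). A quick sketch:

import numpy as np
import torch

xs = torch.linspace(0.0, 9.0, 5)
xp = np.arange(10)
fp = np.arange(10, dtype=np.float64)
ys = np.interp(xs.cpu(), xp, fp)  # a CPU tensor passes straight through as an array-like
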
@@ -182,8 +182,7 @@ class MaskEncoder(fl.Chain):
             ),
         )
         p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
-        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-        self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
+        self.register_parameter("no_mask_embedding", p)

     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1
@@ -8,7 +8,10 @@ from bitsandbytes.optim import AdamW8bit, Lion8bit  # type: ignore
 from prodigyopt import Prodigy  # type: ignore
 from pydantic import BaseModel, BeforeValidator, ConfigDict
 from torch import Tensor
-from torch.optim import SGD, Adam, AdamW, Optimizer
+from torch.optim.adam import Adam
+from torch.optim.adamw import AdamW
+from torch.optim.optimizer import Optimizer
+from torch.optim.sgd import SGD

 from refiners.training_utils.clock import ClockConfig
 from refiners.training_utils.common import Epoch, Iteration, Step, TimeValue, parse_number_unit_field
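
The import moves here (and in the trainer and tests below) are part of the same typing story: `from torch.optim import SGD` and `from torch.optim.sgd import SGD` bind the same class at runtime, but the package-level re-exports apparently do not resolve as cleanly for pyright as the defining submodules do, so the code now imports each optimizer from where it is defined. Behavior is unchanged, as this sketch shows:

from torch.optim import SGD as SGD_package
from torch.optim.sgd import SGD as SGD_submodule

assert SGD_submodule is SGD_package  # identical class object, only static analysis differs
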
@@ -7,7 +7,6 @@ import torch
 from loguru import logger
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.autograd import backward
-from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
     CosineAnnealingWarmRestarts,
@@ -21,6 +20,7 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
     StepLR,
 )
+from torch.optim.optimizer import Optimizer

 from refiners.fluxion import layers as fl
 from refiners.training_utils.callback import (
@@ -226,7 +226,9 @@ class Trainer(Generic[ConfigType, Batch], ABC):
                 lr_scheduler = ExponentialLR(optimizer=self.optimizer, gamma=config.gamma)
             case LRSchedulerType.COSINE_ANNEALING_LR:
                 lr_scheduler = CosineAnnealingLR(
-                    optimizer=self.optimizer, T_max=scheduler_step_size, eta_min=config.eta_min
+                    optimizer=self.optimizer,
+                    T_max=scheduler_step_size,
+                    eta_min=config.eta_min,  # pyright: ignore[reportArgumentType]
                 )
             case LRSchedulerType.REDUCE_LR_ON_PLATEAU:
                 lr_scheduler = cast(
Binary file not shown (image, 233 KiB -> 234 KiB).
Binary file not shown (image, 315 KiB -> 315 KiB).
@@ -12,7 +12,7 @@ from segment_anything_hq import (  # type: ignore
 )
 from segment_anything_hq.modeling.sam import Sam  # type: ignore
 from tests.foundationals.segment_anything.utils import FacebookSAM, FacebookSAMPredictorHQ, SAMPrompt
-from torch import optim
+from torch.optim.sgd import SGD

 from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.segment_anything.hq_sam import (
@@ -125,13 +125,13 @@ def test_mask_decoder_tokens_extender() -> None:
         case _:
             raise ValueError

-    optimizer = optim.SGD(mask_decoder_tokens.parameters(), lr=10)
+    optimizer = SGD(mask_decoder_tokens.parameters(), lr=10)
     optimizer.zero_grad()

     ones = torch.ones_like(tokens_before)
     loss = torch.nn.functional.mse_loss(tokens_before, ones)
-    loss.backward()  # type: ignore
-    optimizer.step()
+    loss.backward()  # pyright: ignore[reportUnknownMemberType]
+    optimizer.step()  # pyright: ignore[reportUnknownMemberType]

     tokens_after = mask_decoder_tokens()

@@ -379,7 +379,7 @@ def test_predictor_dense_mask(
     mask_input = logits[:, scores.max(dim=0).indices, ...]  # shape: 1xHxW

     assert np.allclose(
-        mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1
+        mask_input.cpu(), facebook_mask_input, atol=1e-1
     )  # Lower doesn't pass, but it's close enough for logits

     refiners_prompt = SAMPrompt(
@@ -2,12 +2,11 @@ import random
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import cast

 import pytest
 import torch
 from torch import Tensor, nn
-from torch.optim import SGD
+from torch.optim.sgd import SGD

 from refiners.fluxion import layers as fl
 from refiners.fluxion.utils import norm
@@ -188,8 +187,7 @@ def test_count_learnable_parameters_with_params() -> None:
         nn.Parameter(torch.randn(5), requires_grad=False),
         nn.Parameter(torch.randn(3, 3), requires_grad=True),
     ]
-    # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
-    assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
+    assert count_learnable_parameters(params) == 13


 def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
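
The test change is the list-inference flavor of the same stub fix: a literal list of nn.Parameter(...) entries is now inferred as list[nn.Parameter], so the cast goes away. A small sketch (only these two entries; the test's full list is longer, hence its count of 13):

import torch
from torch import nn

params = [
    nn.Parameter(torch.randn(5), requires_grad=False),
    nn.Parameter(torch.randn(3, 3), requires_grad=True),
]  # inferred as list[nn.Parameter] under torch 2.4
learnable = sum(p.numel() for p in params if p.requires_grad)  # 9 for this two-element list
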