mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add support for dinov2 giant flavors
This commit is contained in:
parent
04e59bf3d9
commit
06ff2f0a5f
|
@ -18,6 +18,21 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
|
||||||
weights["cls_token"] = weights["cls_token"].squeeze(0)
|
weights["cls_token"] = weights["cls_token"].squeeze(0)
|
||||||
weights["pos_embed"] = weights["pos_embed"].squeeze(0)
|
weights["pos_embed"] = weights["pos_embed"].squeeze(0)
|
||||||
|
|
||||||
|
# rename "w12" to "fc1" and "w3" to "fc2", only for giant model
|
||||||
|
for key in list(weights.keys()):
|
||||||
|
if "w3" in key:
|
||||||
|
new_key = key.replace("w3", "fc2")
|
||||||
|
weights[new_key] = weights.pop(key)
|
||||||
|
elif "w12" in key:
|
||||||
|
# we swap w1 and w2 because of the difference between our GLU implementation and theirs
|
||||||
|
# see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34
|
||||||
|
# and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160
|
||||||
|
weight = weights.pop(key)
|
||||||
|
w1, w2 = weight.chunk(2, dim=0)
|
||||||
|
w21 = torch.cat([w2, w1], dim=0)
|
||||||
|
new_key = key.replace("w12", "fc1")
|
||||||
|
weights[new_key] = w21
|
||||||
|
|
||||||
rename_keys: list[tuple[str, str]] = [
|
rename_keys: list[tuple[str, str]] = [
|
||||||
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
|
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
|
||||||
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
|
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
|
||||||
|
|
|
@ -382,9 +382,11 @@ def download_dinov2():
|
||||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
|
||||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
|
||||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
|
||||||
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
|
||||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
|
||||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
|
||||||
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
|
||||||
|
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
|
||||||
]
|
]
|
||||||
download_files(urls, weights_folder)
|
download_files(urls, weights_folder)
|
||||||
|
|
||||||
|
@ -692,6 +694,12 @@ def convert_dinov2():
|
||||||
"tests/weights/dinov2_vitl14_pretrain.safetensors",
|
"tests/weights/dinov2_vitl14_pretrain.safetensors",
|
||||||
expected_hash="ddd4819f",
|
expected_hash="ddd4819f",
|
||||||
)
|
)
|
||||||
|
run_conversion_script(
|
||||||
|
"convert_dinov2.py",
|
||||||
|
"tests/weights/dinov2_vitg14_pretrain.pth",
|
||||||
|
"tests/weights/dinov2_vitg14_pretrain.safetensors",
|
||||||
|
expected_hash="880c61f5",
|
||||||
|
)
|
||||||
run_conversion_script(
|
run_conversion_script(
|
||||||
"convert_dinov2.py",
|
"convert_dinov2.py",
|
||||||
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
|
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
|
||||||
|
@ -710,6 +718,12 @@ def convert_dinov2():
|
||||||
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
|
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
|
||||||
expected_hash="b1221702",
|
expected_hash="b1221702",
|
||||||
)
|
)
|
||||||
|
run_conversion_script(
|
||||||
|
"convert_dinov2.py",
|
||||||
|
"tests/weights/dinov2_vitg14_reg4_pretrain.pth",
|
||||||
|
"tests/weights/dinov2_vitg14_reg4_pretrain.safetensors",
|
||||||
|
expected_hash="639398eb",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_control_lora_fooocus():
|
def convert_control_lora_fooocus():
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from .dinov2 import (
|
from .dinov2 import (
|
||||||
DINOv2_base,
|
DINOv2_base,
|
||||||
DINOv2_base_reg,
|
DINOv2_base_reg,
|
||||||
|
DINOv2_giant,
|
||||||
|
DINOv2_giant_reg,
|
||||||
DINOv2_large,
|
DINOv2_large,
|
||||||
DINOv2_large_reg,
|
DINOv2_large_reg,
|
||||||
DINOv2_small,
|
DINOv2_small,
|
||||||
|
@ -12,6 +14,8 @@ from .vit import ViT
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DINOv2_base",
|
"DINOv2_base",
|
||||||
"DINOv2_base_reg",
|
"DINOv2_base_reg",
|
||||||
|
"DINOv2_giant",
|
||||||
|
"DINOv2_giant_reg",
|
||||||
"DINOv2_large",
|
"DINOv2_large",
|
||||||
"DINOv2_large_reg",
|
"DINOv2_large_reg",
|
||||||
"DINOv2_small",
|
"DINOv2_small",
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.fluxion.layers.activations import GLU, SiLU
|
||||||
from refiners.fluxion.utils import image_to_tensor, normalize
|
from refiners.fluxion.utils import image_to_tensor, normalize
|
||||||
from refiners.foundationals.dinov2.vit import ViT
|
from refiners.foundationals.dinov2.vit import ViT
|
||||||
|
|
||||||
|
@ -130,22 +131,43 @@ class DINOv2_large(ViT):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: implement SwiGLU layer
|
class DINOv2_giant(ViT):
|
||||||
# class DINOv2_giant2(ViT):
|
"""DINOv2 giant model.
|
||||||
# def __init__(
|
|
||||||
# self,
|
See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
|
||||||
# device: torch.device | str | None = None,
|
for more details.
|
||||||
# dtype: torch.dtype | None = None,
|
|
||||||
# ) -> None:
|
Attributes:
|
||||||
# super().__init__(
|
embedding_dim (int): 1536
|
||||||
# embedding_dim=1536,
|
feedforward_dim (int): 4096
|
||||||
# patch_size=14,
|
patch_size (int): 14
|
||||||
# image_size=518,
|
image_size (int): 518
|
||||||
# num_layers=40,
|
num_layers (int): 40
|
||||||
# num_heads=24,
|
num_heads (int): 24
|
||||||
# device=device,
|
"""
|
||||||
# dtype=dtype,
|
|
||||||
# )
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: torch.device | str | None = None,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize DINOv2 giant model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=1536,
|
||||||
|
feedforward_dim=4096,
|
||||||
|
patch_size=14,
|
||||||
|
image_size=518,
|
||||||
|
num_layers=40,
|
||||||
|
num_heads=24,
|
||||||
|
activation=GLU(SiLU()),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DINOv2_small_reg(ViT):
|
class DINOv2_small_reg(ViT):
|
||||||
|
@ -271,21 +293,44 @@ class DINOv2_large_reg(ViT):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: implement SwiGLU layer
|
class DINOv2_giant_reg(ViT):
|
||||||
# class DINOv2_giant2_reg(ViT):
|
"""DINOv2 giant model with register.
|
||||||
# def __init__(
|
|
||||||
# self,
|
See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
|
||||||
# device: torch.device | str | None = None,
|
and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)
|
||||||
# dtype: torch.dtype | None = None,
|
|
||||||
# ) -> None:
|
Attributes:
|
||||||
# super().__init__(
|
embedding_dim (int): 1536
|
||||||
# embedding_dim=1536,
|
feedforward_dim (int): 4096
|
||||||
# patch_size=14,
|
patch_size (int): 14
|
||||||
# image_size=518,
|
image_size (int): 518
|
||||||
# num_layers=40,
|
num_layers (int): 40
|
||||||
# num_heads=24,
|
num_heads (int): 24
|
||||||
# num_registers=4,
|
num_registers (int): 4
|
||||||
# interpolate_antialias=True,
|
interpolate_antialias (bool): True
|
||||||
# device=device,
|
"""
|
||||||
# dtype=dtype,
|
|
||||||
# )
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: torch.device | str | None = None,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize DINOv2 giant model with register.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device | str | None): The PyTorch device to use.
|
||||||
|
dtype (torch.dtype | None): The PyTorch data type to use.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
embedding_dim=1536,
|
||||||
|
feedforward_dim=4096,
|
||||||
|
patch_size=14,
|
||||||
|
image_size=518,
|
||||||
|
num_layers=40,
|
||||||
|
num_heads=24,
|
||||||
|
num_registers=4,
|
||||||
|
interpolate_antialias=True,
|
||||||
|
activation=GLU(SiLU()),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
|
@ -137,21 +137,22 @@ class FeedForward(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
feedforward_dim: int,
|
feedforward_dim: int,
|
||||||
activation: Activation = fl.GeLU, # type: ignore
|
activation: Activation,
|
||||||
device: torch.device | str | None = None,
|
device: torch.device | str | None = None,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.feedforward_dim = feedforward_dim
|
self.feedforward_dim = feedforward_dim
|
||||||
|
pre_activation_dim = feedforward_dim * 2 if isinstance(activation, fl.GLU) else feedforward_dim
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
in_features=embedding_dim,
|
in_features=embedding_dim,
|
||||||
out_features=feedforward_dim,
|
out_features=pre_activation_dim,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
activation(),
|
activation,
|
||||||
fl.Linear(
|
fl.Linear(
|
||||||
in_features=feedforward_dim,
|
in_features=feedforward_dim,
|
||||||
out_features=embedding_dim,
|
out_features=embedding_dim,
|
||||||
|
@ -200,6 +201,8 @@ class TransformerLayer(fl.Chain):
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
mlp_ratio: int,
|
mlp_ratio: int,
|
||||||
|
activation: Activation,
|
||||||
|
feedforward_dim: int | None = None,
|
||||||
device: torch.device | str | None = None,
|
device: torch.device | str | None = None,
|
||||||
dtype: torch.dtype | None = None,
|
dtype: torch.dtype | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -207,6 +210,7 @@ class TransformerLayer(fl.Chain):
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.feedforward_dim = feedforward_dim if feedforward_dim is not None else embedding_dim * mlp_ratio
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Residual(
|
fl.Residual(
|
||||||
|
@ -237,7 +241,8 @@ class TransformerLayer(fl.Chain):
|
||||||
),
|
),
|
||||||
FeedForward(
|
FeedForward(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
feedforward_dim=embedding_dim * mlp_ratio,
|
feedforward_dim=self.feedforward_dim,
|
||||||
|
activation=activation,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
|
@ -300,6 +305,8 @@ class ViT(fl.Chain):
|
||||||
norm_eps: float = 1e-6,
|
norm_eps: float = 1e-6,
|
||||||
mlp_ratio: int = 4,
|
mlp_ratio: int = 4,
|
||||||
num_registers: int = 0,
|
num_registers: int = 0,
|
||||||
|
activation: Activation = fl.GeLU(),
|
||||||
|
feedforward_dim: int | None = None,
|
||||||
interpolate_antialias: bool = False,
|
interpolate_antialias: bool = False,
|
||||||
interpolate_mode: str = "bicubic",
|
interpolate_mode: str = "bicubic",
|
||||||
device: torch.device | str | None = None,
|
device: torch.device | str | None = None,
|
||||||
|
@ -316,6 +323,8 @@ class ViT(fl.Chain):
|
||||||
norm_eps: The epsilon value for normalization.
|
norm_eps: The epsilon value for normalization.
|
||||||
mlp_ratio: The ratio for the multi-layer perceptron (MLP).
|
mlp_ratio: The ratio for the multi-layer perceptron (MLP).
|
||||||
num_registers: The number of registers.
|
num_registers: The number of registers.
|
||||||
|
activation: The activation function.
|
||||||
|
feedforward_dim: The dimension of the feedforward layer.
|
||||||
interpolate_antialias: Whether to use antialiasing for interpolation.
|
interpolate_antialias: Whether to use antialiasing for interpolation.
|
||||||
interpolate_mode: The interpolation mode.
|
interpolate_mode: The interpolation mode.
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
|
@ -330,6 +339,7 @@ class ViT(fl.Chain):
|
||||||
self.norm_eps = norm_eps
|
self.norm_eps = norm_eps
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
self.num_registers = num_registers
|
self.num_registers = num_registers
|
||||||
|
self.feedforward_dim = feedforward_dim
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Concatenate(
|
fl.Concatenate(
|
||||||
|
@ -370,6 +380,8 @@ class ViT(fl.Chain):
|
||||||
Transformer(
|
Transformer(
|
||||||
TransformerLayer(
|
TransformerLayer(
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
|
feedforward_dim=feedforward_dim,
|
||||||
|
activation=activation,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
norm_eps=norm_eps,
|
norm_eps=norm_eps,
|
||||||
|
|
|
@ -9,6 +9,8 @@ from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_s
|
||||||
from refiners.foundationals.dinov2.dinov2 import (
|
from refiners.foundationals.dinov2.dinov2 import (
|
||||||
DINOv2_base,
|
DINOv2_base,
|
||||||
DINOv2_base_reg,
|
DINOv2_base_reg,
|
||||||
|
DINOv2_giant,
|
||||||
|
DINOv2_giant_reg,
|
||||||
DINOv2_large,
|
DINOv2_large,
|
||||||
DINOv2_large_reg,
|
DINOv2_large_reg,
|
||||||
DINOv2_small,
|
DINOv2_small,
|
||||||
|
@ -23,9 +25,8 @@ FLAVORS_MAP = {
|
||||||
"dinov2_vitb14_reg": DINOv2_base_reg,
|
"dinov2_vitb14_reg": DINOv2_base_reg,
|
||||||
"dinov2_vitl14": DINOv2_large,
|
"dinov2_vitl14": DINOv2_large,
|
||||||
"dinov2_vitl14_reg": DINOv2_large_reg,
|
"dinov2_vitl14_reg": DINOv2_large_reg,
|
||||||
# TODO: support giant flavors
|
"dinov2_vitg14": DINOv2_giant,
|
||||||
# "dinov2_vitg14": DINOv2_giant,
|
"dinov2_vitg14_reg": DINOv2_giant_reg,
|
||||||
# "dinov2_vitg14_reg": DINOv2_giant_reg,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue