add support for dinov2 giant flavors

Laurent 2024-04-11 12:14:18 +00:00 committed by Laureηt
parent 04e59bf3d9
commit 06ff2f0a5f
6 changed files with 132 additions and 41 deletions

View file

@@ -18,6 +18,21 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
    weights["cls_token"] = weights["cls_token"].squeeze(0)
    weights["pos_embed"] = weights["pos_embed"].squeeze(0)

    # rename "w12" to "fc1" and "w3" to "fc2", only for giant model
    for key in list(weights.keys()):
        if "w3" in key:
            new_key = key.replace("w3", "fc2")
            weights[new_key] = weights.pop(key)
        elif "w12" in key:
            # we swap w1 and w2 because of the difference between our GLU implementation and theirs
            # see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34
            # and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160
            weight = weights.pop(key)
            w1, w2 = weight.chunk(2, dim=0)
            w21 = torch.cat([w2, w1], dim=0)
            new_key = key.replace("w12", "fc1")
            weights[new_key] = w21

    rename_keys: list[tuple[str, str]] = [
        ("cls_token", "Concatenate.ClassToken.Parameter.weight"),
        ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),

View file

@@ -382,9 +382,11 @@ def download_dinov2():
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
    ]
    download_files(urls, weights_folder)
@@ -692,6 +694,12 @@ def convert_dinov2():
        "tests/weights/dinov2_vitl14_pretrain.safetensors",
        expected_hash="ddd4819f",
    )
    run_conversion_script(
        "convert_dinov2.py",
        "tests/weights/dinov2_vitg14_pretrain.pth",
        "tests/weights/dinov2_vitg14_pretrain.safetensors",
        expected_hash="880c61f5",
    )
    run_conversion_script(
        "convert_dinov2.py",
        "tests/weights/dinov2_vits14_reg4_pretrain.pth",
@@ -710,6 +718,12 @@ def convert_dinov2():
        "tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
        expected_hash="b1221702",
    )
    run_conversion_script(
        "convert_dinov2.py",
        "tests/weights/dinov2_vitg14_reg4_pretrain.pth",
        "tests/weights/dinov2_vitg14_reg4_pretrain.safetensors",
        expected_hash="639398eb",
    )


def convert_control_lora_fooocus():

View file

@@ -1,6 +1,8 @@
from .dinov2 import (
    DINOv2_base,
    DINOv2_base_reg,
    DINOv2_giant,
    DINOv2_giant_reg,
    DINOv2_large,
    DINOv2_large_reg,
    DINOv2_small,
@@ -12,6 +14,8 @@ from .vit import ViT
__all__ = [
    "DINOv2_base",
    "DINOv2_base_reg",
    "DINOv2_giant",
    "DINOv2_giant_reg",
    "DINOv2_large",
    "DINOv2_large_reg",
    "DINOv2_small",

View file

@@ -1,6 +1,7 @@
import torch
from PIL import Image

from refiners.fluxion.layers.activations import GLU, SiLU
from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.dinov2.vit import ViT
@@ -130,22 +131,43 @@ class DINOv2_large(ViT):
    )


# TODO: implement SwiGLU layer
# class DINOv2_giant2(ViT):
#     def __init__(
#         self,
#         device: torch.device | str | None = None,
#         dtype: torch.dtype | None = None,
#     ) -> None:
#         super().__init__(
#             embedding_dim=1536,
#             patch_size=14,
#             image_size=518,
#             num_layers=40,
#             num_heads=24,
#             device=device,
#             dtype=dtype,
#         )


class DINOv2_giant(ViT):
    """DINOv2 giant model.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    for more details.

    Attributes:
        embedding_dim (int): 1536
        feedforward_dim (int): 4096
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 40
        num_heads (int): 24
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 giant model.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=1536,
            feedforward_dim=4096,
            patch_size=14,
            image_size=518,
            num_layers=40,
            num_heads=24,
            activation=GLU(SiLU()),
            device=device,
            dtype=dtype,
        )


class DINOv2_small_reg(ViT):
@@ -271,21 +293,44 @@ class DINOv2_large_reg(ViT):
    )


# TODO: implement SwiGLU layer
# class DINOv2_giant2_reg(ViT):
#     def __init__(
#         self,
#         device: torch.device | str | None = None,
#         dtype: torch.dtype | None = None,
#     ) -> None:
#         super().__init__(
#             embedding_dim=1536,
#             patch_size=14,
#             image_size=518,
#             num_layers=40,
#             num_heads=24,
#             num_registers=4,
#             interpolate_antialias=True,
#             device=device,
#             dtype=dtype,
#         )


class DINOv2_giant_reg(ViT):
    """DINOv2 giant model with registers.

    See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
    and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)
    for more details.

    Attributes:
        embedding_dim (int): 1536
        feedforward_dim (int): 4096
        patch_size (int): 14
        image_size (int): 518
        num_layers (int): 40
        num_heads (int): 24
        num_registers (int): 4
        interpolate_antialias (bool): True
    """

    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize DINOv2 giant model with registers.

        Args:
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
        """
        super().__init__(
            embedding_dim=1536,
            feedforward_dim=4096,
            patch_size=14,
            image_size=518,
            num_layers=40,
            num_heads=24,
            num_registers=4,
            interpolate_antialias=True,
            activation=GLU(SiLU()),
            device=device,
            dtype=dtype,
        )
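Note: a minimal usage sketch for the new flavor; the weights path comes from the conversion step above, and the random input stands in for an already-normalized 518x518 RGB image:

import torch
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.dinov2 import DINOv2_giant

model = DINOv2_giant(device="cpu")
model.load_state_dict(load_from_safetensors("tests/weights/dinov2_vitg14_pretrain.safetensors"))

x = torch.randn(1, 3, 518, 518)
tokens = model(x)
print(tokens.shape)  # expected: (1, 1370, 1536) -- 1 CLS token + 37 * 37 patch tokens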

View file

@@ -137,21 +137,22 @@ class FeedForward(fl.Chain):
        self,
        embedding_dim: int,
        feedforward_dim: int,
        activation: Activation = fl.GeLU,  # type: ignore
        activation: Activation,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        pre_activation_dim = feedforward_dim * 2 if isinstance(activation, fl.GLU) else feedforward_dim
        super().__init__(
            fl.Linear(
                in_features=embedding_dim,
                out_features=feedforward_dim,
                out_features=pre_activation_dim,
                device=device,
                dtype=dtype,
            ),
            activation(),
            activation,
            fl.Linear(
                in_features=feedforward_dim,
                out_features=embedding_dim,
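Note: the pre_activation_dim doubling exists because a GLU-style activation consumes a value half and a gate half and emits one, so the block still feeds feedforward_dim features into the second Linear. A shape-only sketch, assuming refiners' GLU chunks the last dimension:

import torch
from refiners.fluxion.layers.activations import GLU, SiLU

x = torch.randn(2, 16, 2 * 4096)  # pre_activation_dim = 2 * feedforward_dim
print(GLU(SiLU())(x).shape)  # expected: (2, 16, 4096) -- the gate half is consumed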
@@ -200,6 +201,8 @@ class TransformerLayer(fl.Chain):
        num_heads: int,
        norm_eps: float,
        mlp_ratio: int,
        activation: Activation,
        feedforward_dim: int | None = None,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
@@ -207,6 +210,7 @@ class TransformerLayer(fl.Chain):
        self.num_heads = num_heads
        self.norm_eps = norm_eps
        self.mlp_ratio = mlp_ratio
        self.feedforward_dim = feedforward_dim if feedforward_dim is not None else embedding_dim * mlp_ratio

        super().__init__(
            fl.Residual(
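Note: with feedforward_dim left as None, the giant would get the default 1536 * 4 = 6144; DINOv2's fused SwiGLU instead uses roughly two thirds of that to keep the parameter count comparable to a plain MLP, which is where the explicit 4096 in the new classes comes from. The arithmetic, assuming upstream's round-to-a-multiple-of-8 rule from swiglu_ffn.py:

embedding_dim, mlp_ratio = 1536, 4
default_dim = embedding_dim * mlp_ratio              # 6144
swiglu_dim = (int(default_dim * 2 / 3) + 7) // 8 * 8
print(swiglu_dim)  # 4096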
@@ -237,7 +241,8 @@ class TransformerLayer(fl.Chain):
            ),
            FeedForward(
                embedding_dim=embedding_dim,
                feedforward_dim=embedding_dim * mlp_ratio,
                feedforward_dim=self.feedforward_dim,
                activation=activation,
                device=device,
                dtype=dtype,
            ),
@@ -300,6 +305,8 @@ class ViT(fl.Chain):
        norm_eps: float = 1e-6,
        mlp_ratio: int = 4,
        num_registers: int = 0,
        activation: Activation = fl.GeLU(),
        feedforward_dim: int | None = None,
        interpolate_antialias: bool = False,
        interpolate_mode: str = "bicubic",
        device: torch.device | str | None = None,
@@ -316,6 +323,8 @@ class ViT(fl.Chain):
            norm_eps: The epsilon value for normalization.
            mlp_ratio: The ratio for the multi-layer perceptron (MLP).
            num_registers: The number of registers.
            activation: The activation function.
            feedforward_dim: The dimension of the feedforward layer.
            interpolate_antialias: Whether to use antialiasing for interpolation.
            interpolate_mode: The interpolation mode.
            device: The PyTorch device to use.
@@ -330,6 +339,7 @@ class ViT(fl.Chain):
        self.norm_eps = norm_eps
        self.mlp_ratio = mlp_ratio
        self.num_registers = num_registers
        self.feedforward_dim = feedforward_dim

        super().__init__(
            fl.Concatenate(
@@ -370,6 +380,8 @@ class ViT(fl.Chain):
                Transformer(
                    TransformerLayer(
                        embedding_dim=embedding_dim,
                        feedforward_dim=feedforward_dim,
                        activation=activation,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        norm_eps=norm_eps,

View file

@@ -9,6 +9,8 @@ from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed
from refiners.foundationals.dinov2.dinov2 import (
    DINOv2_base,
    DINOv2_base_reg,
    DINOv2_giant,
    DINOv2_giant_reg,
    DINOv2_large,
    DINOv2_large_reg,
    DINOv2_small,
@@ -23,9 +25,8 @@ FLAVORS_MAP = {
    "dinov2_vitb14_reg": DINOv2_base_reg,
    "dinov2_vitl14": DINOv2_large,
    "dinov2_vitl14_reg": DINOv2_large_reg,
    # TODO: support giant flavors
    # "dinov2_vitg14": DINOv2_giant,
    # "dinov2_vitg14_reg": DINOv2_giant_reg,
    "dinov2_vitg14": DINOv2_giant,
    "dinov2_vitg14_reg": DINOv2_giant_reg,
}
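Note: with the map completed, the giant flavors are instantiated by name exactly like the other entries, e.g. (hypothetical lookup mirroring the parametrized tests):

model = FLAVORS_MAP["dinov2_vitg14_reg"]()  # a DINOv2_giant_reg instance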