add support for dinov2 giant flavors

This commit is contained in:
Laurent 2024-04-11 12:14:18 +00:00 committed by Laureηt
parent 04e59bf3d9
commit 06ff2f0a5f
6 changed files with 132 additions and 41 deletions

View file

@ -18,6 +18,21 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
weights["cls_token"] = weights["cls_token"].squeeze(0) weights["cls_token"] = weights["cls_token"].squeeze(0)
weights["pos_embed"] = weights["pos_embed"].squeeze(0) weights["pos_embed"] = weights["pos_embed"].squeeze(0)
# rename "w12" to "fc1" and "w3" to "fc2", only for giant model
for key in list(weights.keys()):
if "w3" in key:
new_key = key.replace("w3", "fc2")
weights[new_key] = weights.pop(key)
elif "w12" in key:
# we swap w1 and w2 because of the difference between our GLU implementation and theirs
# see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34
# and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160
weight = weights.pop(key)
w1, w2 = weight.chunk(2, dim=0)
w21 = torch.cat([w2, w1], dim=0)
new_key = key.replace("w12", "fc1")
weights[new_key] = w21
rename_keys: list[tuple[str, str]] = [ rename_keys: list[tuple[str, str]] = [
("cls_token", "Concatenate.ClassToken.Parameter.weight"), ("cls_token", "Concatenate.ClassToken.Parameter.weight"),
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"), ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),

View file

@ -382,9 +382,11 @@ def download_dinov2():
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth", "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
] ]
download_files(urls, weights_folder) download_files(urls, weights_folder)
@ -692,6 +694,12 @@ def convert_dinov2():
"tests/weights/dinov2_vitl14_pretrain.safetensors", "tests/weights/dinov2_vitl14_pretrain.safetensors",
expected_hash="ddd4819f", expected_hash="ddd4819f",
) )
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitg14_pretrain.pth",
"tests/weights/dinov2_vitg14_pretrain.safetensors",
expected_hash="880c61f5",
)
run_conversion_script( run_conversion_script(
"convert_dinov2.py", "convert_dinov2.py",
"tests/weights/dinov2_vits14_reg4_pretrain.pth", "tests/weights/dinov2_vits14_reg4_pretrain.pth",
@ -710,6 +718,12 @@ def convert_dinov2():
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors", "tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
expected_hash="b1221702", expected_hash="b1221702",
) )
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitg14_reg4_pretrain.pth",
"tests/weights/dinov2_vitg14_reg4_pretrain.safetensors",
expected_hash="639398eb",
)
def convert_control_lora_fooocus(): def convert_control_lora_fooocus():

View file

@ -1,6 +1,8 @@
from .dinov2 import ( from .dinov2 import (
DINOv2_base, DINOv2_base,
DINOv2_base_reg, DINOv2_base_reg,
DINOv2_giant,
DINOv2_giant_reg,
DINOv2_large, DINOv2_large,
DINOv2_large_reg, DINOv2_large_reg,
DINOv2_small, DINOv2_small,
@ -12,6 +14,8 @@ from .vit import ViT
__all__ = [ __all__ = [
"DINOv2_base", "DINOv2_base",
"DINOv2_base_reg", "DINOv2_base_reg",
"DINOv2_giant",
"DINOv2_giant_reg",
"DINOv2_large", "DINOv2_large",
"DINOv2_large_reg", "DINOv2_large_reg",
"DINOv2_small", "DINOv2_small",

View file

@ -1,6 +1,7 @@
import torch import torch
from PIL import Image from PIL import Image
from refiners.fluxion.layers.activations import GLU, SiLU
from refiners.fluxion.utils import image_to_tensor, normalize from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.dinov2.vit import ViT from refiners.foundationals.dinov2.vit import ViT
@ -130,22 +131,43 @@ class DINOv2_large(ViT):
) )
# TODO: implement SwiGLU layer class DINOv2_giant(ViT):
# class DINOv2_giant2(ViT): """DINOv2 giant model.
# def __init__(
# self, See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
# device: torch.device | str | None = None, for more details.
# dtype: torch.dtype | None = None,
# ) -> None: Attributes:
# super().__init__( embedding_dim (int): 1536
# embedding_dim=1536, feedforward_dim (int): 4096
# patch_size=14, patch_size (int): 14
# image_size=518, image_size (int): 518
# num_layers=40, num_layers (int): 40
# num_heads=24, num_heads (int): 24
# device=device, """
# dtype=dtype,
# ) def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""Initialize DINOv2 giant model.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
embedding_dim=1536,
feedforward_dim=4096,
patch_size=14,
image_size=518,
num_layers=40,
num_heads=24,
activation=GLU(SiLU()),
device=device,
dtype=dtype,
)
class DINOv2_small_reg(ViT): class DINOv2_small_reg(ViT):
@ -271,21 +293,44 @@ class DINOv2_large_reg(ViT):
) )
# TODO: implement SwiGLU layer class DINOv2_giant_reg(ViT):
# class DINOv2_giant2_reg(ViT): """DINOv2 giant model with register.
# def __init__(
# self, See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)
# device: torch.device | str | None = None, and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588)
# dtype: torch.dtype | None = None,
# ) -> None: Attributes:
# super().__init__( embedding_dim (int): 1536
# embedding_dim=1536, feedforward_dim (int): 4096
# patch_size=14, patch_size (int): 14
# image_size=518, image_size (int): 518
# num_layers=40, num_layers (int): 40
# num_heads=24, num_heads (int): 24
# num_registers=4, num_registers (int): 4
# interpolate_antialias=True, interpolate_antialias (bool): True
# device=device, """
# dtype=dtype,
# ) def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
"""Initialize DINOv2 giant model with register.
Args:
device (torch.device | str | None): The PyTorch device to use.
dtype (torch.dtype | None): The PyTorch data type to use.
"""
super().__init__(
embedding_dim=1536,
feedforward_dim=4096,
patch_size=14,
image_size=518,
num_layers=40,
num_heads=24,
num_registers=4,
interpolate_antialias=True,
activation=GLU(SiLU()),
device=device,
dtype=dtype,
)

View file

@ -137,21 +137,22 @@ class FeedForward(fl.Chain):
self, self,
embedding_dim: int, embedding_dim: int,
feedforward_dim: int, feedforward_dim: int,
activation: Activation = fl.GeLU, # type: ignore activation: Activation,
device: torch.device | str | None = None, device: torch.device | str | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
) -> None: ) -> None:
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim self.feedforward_dim = feedforward_dim
pre_activation_dim = feedforward_dim * 2 if isinstance(activation, fl.GLU) else feedforward_dim
super().__init__( super().__init__(
fl.Linear( fl.Linear(
in_features=embedding_dim, in_features=embedding_dim,
out_features=feedforward_dim, out_features=pre_activation_dim,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
activation(), activation,
fl.Linear( fl.Linear(
in_features=feedforward_dim, in_features=feedforward_dim,
out_features=embedding_dim, out_features=embedding_dim,
@ -200,6 +201,8 @@ class TransformerLayer(fl.Chain):
num_heads: int, num_heads: int,
norm_eps: float, norm_eps: float,
mlp_ratio: int, mlp_ratio: int,
activation: Activation,
feedforward_dim: int | None = None,
device: torch.device | str | None = None, device: torch.device | str | None = None,
dtype: torch.dtype | None = None, dtype: torch.dtype | None = None,
) -> None: ) -> None:
@ -207,6 +210,7 @@ class TransformerLayer(fl.Chain):
self.num_heads = num_heads self.num_heads = num_heads
self.norm_eps = norm_eps self.norm_eps = norm_eps
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.feedforward_dim = feedforward_dim if feedforward_dim is not None else embedding_dim * mlp_ratio
super().__init__( super().__init__(
fl.Residual( fl.Residual(
@ -237,7 +241,8 @@ class TransformerLayer(fl.Chain):
), ),
FeedForward( FeedForward(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
feedforward_dim=embedding_dim * mlp_ratio, feedforward_dim=self.feedforward_dim,
activation=activation,
device=device, device=device,
dtype=dtype, dtype=dtype,
), ),
@ -300,6 +305,8 @@ class ViT(fl.Chain):
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
mlp_ratio: int = 4, mlp_ratio: int = 4,
num_registers: int = 0, num_registers: int = 0,
activation: Activation = fl.GeLU(),
feedforward_dim: int | None = None,
interpolate_antialias: bool = False, interpolate_antialias: bool = False,
interpolate_mode: str = "bicubic", interpolate_mode: str = "bicubic",
device: torch.device | str | None = None, device: torch.device | str | None = None,
@ -316,6 +323,8 @@ class ViT(fl.Chain):
norm_eps: The epsilon value for normalization. norm_eps: The epsilon value for normalization.
mlp_ratio: The ratio for the multi-layer perceptron (MLP). mlp_ratio: The ratio for the multi-layer perceptron (MLP).
num_registers: The number of registers. num_registers: The number of registers.
activation: The activation function.
feedforward_dim: The dimension of the feedforward layer.
interpolate_antialias: Whether to use antialiasing for interpolation. interpolate_antialias: Whether to use antialiasing for interpolation.
interpolate_mode: The interpolation mode. interpolate_mode: The interpolation mode.
device: The PyTorch device to use. device: The PyTorch device to use.
@ -330,6 +339,7 @@ class ViT(fl.Chain):
self.norm_eps = norm_eps self.norm_eps = norm_eps
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.num_registers = num_registers self.num_registers = num_registers
self.feedforward_dim = feedforward_dim
super().__init__( super().__init__(
fl.Concatenate( fl.Concatenate(
@ -370,6 +380,8 @@ class ViT(fl.Chain):
Transformer( Transformer(
TransformerLayer( TransformerLayer(
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
feedforward_dim=feedforward_dim,
activation=activation,
num_heads=num_heads, num_heads=num_heads,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
norm_eps=norm_eps, norm_eps=norm_eps,

View file

@ -9,6 +9,8 @@ from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_s
from refiners.foundationals.dinov2.dinov2 import ( from refiners.foundationals.dinov2.dinov2 import (
DINOv2_base, DINOv2_base,
DINOv2_base_reg, DINOv2_base_reg,
DINOv2_giant,
DINOv2_giant_reg,
DINOv2_large, DINOv2_large,
DINOv2_large_reg, DINOv2_large_reg,
DINOv2_small, DINOv2_small,
@ -23,9 +25,8 @@ FLAVORS_MAP = {
"dinov2_vitb14_reg": DINOv2_base_reg, "dinov2_vitb14_reg": DINOv2_base_reg,
"dinov2_vitl14": DINOv2_large, "dinov2_vitl14": DINOv2_large,
"dinov2_vitl14_reg": DINOv2_large_reg, "dinov2_vitl14_reg": DINOv2_large_reg,
# TODO: support giant flavors "dinov2_vitg14": DINOv2_giant,
# "dinov2_vitg14": DINOv2_giant, "dinov2_vitg14_reg": DINOv2_giant_reg,
# "dinov2_vitg14_reg": DINOv2_giant_reg,
} }