diff --git a/scripts/conversion/convert_dinov2.py b/scripts/conversion/convert_dinov2.py
index 9e907b4..ec4b204 100644
--- a/scripts/conversion/convert_dinov2.py
+++ b/scripts/conversion/convert_dinov2.py
@@ -18,6 +18,21 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:
     weights["cls_token"] = weights["cls_token"].squeeze(0)
     weights["pos_embed"] = weights["pos_embed"].squeeze(0)
 
+    # rename "w12" to "fc1" and "w3" to "fc2", only for giant model
+    for key in list(weights.keys()):
+        if "w3" in key:
+            new_key = key.replace("w3", "fc2")
+            weights[new_key] = weights.pop(key)
+        elif "w12" in key:
+            # we swap w1 and w2 because of the difference between our GLU implementation and theirs
+            # see https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/layers/swiglu_ffn.py#L31-L34
+            # and https://github.com/finegrain-ai/refiners/blob/a2ee70578361e4d84a65a8708564480a9b0ec67e/src/refiners/fluxion/layers/activations.py#L158-L160
+            weight = weights.pop(key)
+            w1, w2 = weight.chunk(2, dim=0)
+            w21 = torch.cat([w2, w1], dim=0)
+            new_key = key.replace("w12", "fc1")
+            weights[new_key] = w21
+
     rename_keys: list[tuple[str, str]] = [
         ("cls_token", "Concatenate.ClassToken.Parameter.weight"),
         ("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index eb759d6..8e70c4a 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -382,9 +382,11 @@ def download_dinov2():
         "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
         "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
         "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
+        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth",
         "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
         "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
         "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
+        "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth",
     ]
     download_files(urls, weights_folder)
 
@@ -692,6 +694,12 @@ def convert_dinov2():
         "tests/weights/dinov2_vitl14_pretrain.safetensors",
         expected_hash="ddd4819f",
     )
+    run_conversion_script(
+        "convert_dinov2.py",
+        "tests/weights/dinov2_vitg14_pretrain.pth",
+        "tests/weights/dinov2_vitg14_pretrain.safetensors",
+        expected_hash="880c61f5",
+    )
     run_conversion_script(
         "convert_dinov2.py",
         "tests/weights/dinov2_vits14_reg4_pretrain.pth",
@@ -710,6 +718,12 @@
         "tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
         expected_hash="b1221702",
     )
+    run_conversion_script(
+        "convert_dinov2.py",
+        "tests/weights/dinov2_vitg14_reg4_pretrain.pth",
+        "tests/weights/dinov2_vitg14_reg4_pretrain.safetensors",
+        expected_hash="639398eb",
+    )
 
 
 def convert_control_lora_fooocus():
diff --git a/src/refiners/foundationals/dinov2/__init__.py b/src/refiners/foundationals/dinov2/__init__.py
index 5ef56d9..a5fdb1f 100644
--- a/src/refiners/foundationals/dinov2/__init__.py
+++ b/src/refiners/foundationals/dinov2/__init__.py
@@ -1,6 +1,8 @@
 from .dinov2 import (
     DINOv2_base,
     DINOv2_base_reg,
+    DINOv2_giant,
+    DINOv2_giant_reg,
     DINOv2_large,
     DINOv2_large_reg,
     DINOv2_small,
@@ -12,6 +14,8 @@ from .vit import ViT
 __all__ = [
     "DINOv2_base",
     "DINOv2_base_reg",
+    "DINOv2_giant",
+    "DINOv2_giant_reg",
     "DINOv2_large",
"DINOv2_large_reg", "DINOv2_small", diff --git a/src/refiners/foundationals/dinov2/dinov2.py b/src/refiners/foundationals/dinov2/dinov2.py index a984007..373315d 100644 --- a/src/refiners/foundationals/dinov2/dinov2.py +++ b/src/refiners/foundationals/dinov2/dinov2.py @@ -1,6 +1,7 @@ import torch from PIL import Image +from refiners.fluxion.layers.activations import GLU, SiLU from refiners.fluxion.utils import image_to_tensor, normalize from refiners.foundationals.dinov2.vit import ViT @@ -130,22 +131,43 @@ class DINOv2_large(ViT): ) -# TODO: implement SwiGLU layer -# class DINOv2_giant2(ViT): -# def __init__( -# self, -# device: torch.device | str | None = None, -# dtype: torch.dtype | None = None, -# ) -> None: -# super().__init__( -# embedding_dim=1536, -# patch_size=14, -# image_size=518, -# num_layers=40, -# num_heads=24, -# device=device, -# dtype=dtype, -# ) +class DINOv2_giant(ViT): + """DINOv2 giant model. + + See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) + for more details. + + Attributes: + embedding_dim (int): 1536 + feedforward_dim (int): 4096 + patch_size (int): 14 + image_size (int): 518 + num_layers (int): 40 + num_heads (int): 24 + """ + + def __init__( + self, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> None: + """Initialize DINOv2 giant model. + + Args: + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ + super().__init__( + embedding_dim=1536, + feedforward_dim=4096, + patch_size=14, + image_size=518, + num_layers=40, + num_heads=24, + activation=GLU(SiLU()), + device=device, + dtype=dtype, + ) class DINOv2_small_reg(ViT): @@ -271,21 +293,44 @@ class DINOv2_large_reg(ViT): ) -# TODO: implement SwiGLU layer -# class DINOv2_giant2_reg(ViT): -# def __init__( -# self, -# device: torch.device | str | None = None, -# dtype: torch.dtype | None = None, -# ) -> None: -# super().__init__( -# embedding_dim=1536, -# patch_size=14, -# image_size=518, -# num_layers=40, -# num_heads=24, -# num_registers=4, -# interpolate_antialias=True, -# device=device, -# dtype=dtype, -# ) +class DINOv2_giant_reg(ViT): + """DINOv2 giant model with register. + + See [[arXiv:2304.07193] DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193) + and [[arXiv:2309.16588] Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588) + + Attributes: + embedding_dim (int): 1536 + feedforward_dim (int): 4096 + patch_size (int): 14 + image_size (int): 518 + num_layers (int): 40 + num_heads (int): 24 + num_registers (int): 4 + interpolate_antialias (bool): True + """ + + def __init__( + self, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> None: + """Initialize DINOv2 giant model with register. + + Args: + device (torch.device | str | None): The PyTorch device to use. + dtype (torch.dtype | None): The PyTorch data type to use. 
+ """ + super().__init__( + embedding_dim=1536, + feedforward_dim=4096, + patch_size=14, + image_size=518, + num_layers=40, + num_heads=24, + num_registers=4, + interpolate_antialias=True, + activation=GLU(SiLU()), + device=device, + dtype=dtype, + ) diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py index 4808d6f..7a4e489 100644 --- a/src/refiners/foundationals/dinov2/vit.py +++ b/src/refiners/foundationals/dinov2/vit.py @@ -137,21 +137,22 @@ class FeedForward(fl.Chain): self, embedding_dim: int, feedforward_dim: int, - activation: Activation = fl.GeLU, # type: ignore + activation: Activation, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: self.embedding_dim = embedding_dim self.feedforward_dim = feedforward_dim + pre_activation_dim = feedforward_dim * 2 if isinstance(activation, fl.GLU) else feedforward_dim super().__init__( fl.Linear( in_features=embedding_dim, - out_features=feedforward_dim, + out_features=pre_activation_dim, device=device, dtype=dtype, ), - activation(), + activation, fl.Linear( in_features=feedforward_dim, out_features=embedding_dim, @@ -200,6 +201,8 @@ class TransformerLayer(fl.Chain): num_heads: int, norm_eps: float, mlp_ratio: int, + activation: Activation, + feedforward_dim: int | None = None, device: torch.device | str | None = None, dtype: torch.dtype | None = None, ) -> None: @@ -207,6 +210,7 @@ class TransformerLayer(fl.Chain): self.num_heads = num_heads self.norm_eps = norm_eps self.mlp_ratio = mlp_ratio + self.feedforward_dim = feedforward_dim if feedforward_dim is not None else embedding_dim * mlp_ratio super().__init__( fl.Residual( @@ -237,7 +241,8 @@ class TransformerLayer(fl.Chain): ), FeedForward( embedding_dim=embedding_dim, - feedforward_dim=embedding_dim * mlp_ratio, + feedforward_dim=self.feedforward_dim, + activation=activation, device=device, dtype=dtype, ), @@ -300,6 +305,8 @@ class ViT(fl.Chain): norm_eps: float = 1e-6, mlp_ratio: int = 4, num_registers: int = 0, + activation: Activation = fl.GeLU(), + feedforward_dim: int | None = None, interpolate_antialias: bool = False, interpolate_mode: str = "bicubic", device: torch.device | str | None = None, @@ -316,6 +323,8 @@ class ViT(fl.Chain): norm_eps: The epsilon value for normalization. mlp_ratio: The ratio for the multi-layer perceptron (MLP). num_registers: The number of registers. + activation: The activation function. + feedforward_dim: The dimension of the feedforward layer. interpolate_antialias: Whether to use antialiasing for interpolation. interpolate_mode: The interpolation mode. device: The PyTorch device to use. 
@@ -330,6 +339,7 @@
         self.norm_eps = norm_eps
         self.mlp_ratio = mlp_ratio
         self.num_registers = num_registers
+        self.feedforward_dim = feedforward_dim
 
         super().__init__(
             fl.Concatenate(
@@ -370,6 +380,8 @@
             Transformer(
                 TransformerLayer(
                     embedding_dim=embedding_dim,
+                    feedforward_dim=feedforward_dim,
+                    activation=activation,
                     num_heads=num_heads,
                     mlp_ratio=mlp_ratio,
                     norm_eps=norm_eps,
diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py
index c3cf298..d994d57 100644
--- a/tests/foundationals/dinov2/test_dinov2.py
+++ b/tests/foundationals/dinov2/test_dinov2.py
@@ -9,6 +9,8 @@ from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_s
 from refiners.foundationals.dinov2.dinov2 import (
     DINOv2_base,
     DINOv2_base_reg,
+    DINOv2_giant,
+    DINOv2_giant_reg,
     DINOv2_large,
     DINOv2_large_reg,
     DINOv2_small,
@@ -23,9 +25,8 @@ FLAVORS_MAP = {
     "dinov2_vitb14_reg": DINOv2_base_reg,
     "dinov2_vitl14": DINOv2_large,
     "dinov2_vitl14_reg": DINOv2_large_reg,
-    # TODO: support giant flavors
-    # "dinov2_vitg14": DINOv2_giant,
-    # "dinov2_vitg14_reg": DINOv2_giant_reg,
+    "dinov2_vitg14": DINOv2_giant,
+    "dinov2_vitg14_reg": DINOv2_giant_reg,
 }
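
Illustration (not part of the patch): why convert_dinov2_facebook swaps the two halves of the fused "w12" weight. Upstream DINOv2's SwiGLU applies SiLU to the first half of the fused projection (see the link in the comment), and the presence of the swap implies that refiners' GLU(SiLU()) instead gates with the second half. The toy check below uses plain PyTorch and made-up dimensions to show the two orderings agree once the halves are restacked as [w2; w1].

import torch

torch.manual_seed(0)
d_model, d_ff = 8, 16  # toy sizes; the giant model uses 1536 and 4096
w12 = torch.randn(2 * d_ff, d_model)  # fused upstream weight, rows stacked as [w1; w2]
x = torch.randn(3, d_model)

# upstream SwiGLU: silu(x @ w1.T) * (x @ w2.T)
x1, x2 = (x @ w12.T).chunk(2, dim=-1)
reference = torch.nn.functional.silu(x1) * x2

# GLU assumed to compute first_half * silu(second_half), fed with the swapped weight
w1, w2 = w12.chunk(2, dim=0)
w21 = torch.cat([w2, w1], dim=0)
out, gate = (x @ w21.T).chunk(2, dim=-1)
converted = out * torch.nn.functional.silu(gate)

assert torch.allclose(reference, converted)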
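
A minimal usage sketch for the new giant flavor, assuming the checkpoint converted above loads as a plain state dict and that image_to_tensor and normalize keep the signatures implied by the imports in dinov2.py; "example.jpg" and the ImageNet mean/std are illustrative placeholders, not values taken from this patch.

import torch
from PIL import Image

from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, normalize
from refiners.foundationals.dinov2 import DINOv2_giant

model = DINOv2_giant(device="cpu")
model.load_state_dict(load_from_safetensors("tests/weights/dinov2_vitg14_pretrain.safetensors"))

image = Image.open("example.jpg").convert("RGB").resize((518, 518))  # image_size=518
x = normalize(
    image_to_tensor(image),  # (1, 3, 518, 518), values in [0, 1]
    mean=[0.485, 0.456, 0.406],  # standard ImageNet statistics
    std=[0.229, 0.224, 0.225],
)
with torch.no_grad():
    tokens = model(x)  # expected (1, 1 + 37 * 37, 1536): class token followed by patch tokens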