diff --git a/src/refiners/foundationals/dinov2/vit.py b/src/refiners/foundationals/dinov2/vit.py
index e2443ef..4808d6f 100644
--- a/src/refiners/foundationals/dinov2/vit.py
+++ b/src/refiners/foundationals/dinov2/vit.py
@@ -73,33 +73,31 @@ class InterpolateEmbedding(fl.Module):
         x: torch.Tensor,
         input: torch.Tensor,
     ) -> torch.Tensor:
-        cls_embed = x[:, :1, :]  # -> (1, 1, D)
-        patch_embed = x[:, 1:, :]  # -> (1, N, D)
+        cls_embed = x[:, :1, :]  # -> (B, 1, D)
+        patch_embed = x[:, 1:, :]  # -> (B, N, D)
 
+        B = patch_embed.shape[0]
         N = patch_embed.shape[1]
         D = patch_embed.shape[2]
         M = int(sqrt(N))
         W = input.shape[2]
         H = input.shape[3]
+        w = W // self.patch_size
+        h = H // self.patch_size
         assert M * M == N, "The sequence length must be a square number."
 
-        patch_embed = patch_embed.reshape(1, M, M, D)  # -> (1, M, M, D)
-        patch_embed = patch_embed.permute(0, 3, 1, 2)  # -> (1, D, M, M)
+        patch_embed = patch_embed.reshape(B, M, M, D)  # -> (B, M, M, D)
+        patch_embed = patch_embed.permute(0, 3, 1, 2)  # -> (B, D, M, M)
         patch_embed = interpolate(
             x=patch_embed.to(dtype=torch.float32),
             mode=self.mode,
             antialias=self.antialias,
-            size=torch.Size(
-                (
-                    W // self.patch_size,
-                    H // self.patch_size,
-                )
-            ),
-        ).to(dtype=cls_embed.dtype)  # -> (1, D, w, h)
-        patch_embed = patch_embed.permute(0, 2, 3, 1)  # -> (1, w, h, D)
-        patch_embed = patch_embed.reshape(1, -1, D)  # -> (1, w*h, D)
+            size=torch.Size((w, h)),
+        ).to(dtype=cls_embed.dtype)  # -> (B, D, w, h)
+        patch_embed = patch_embed.permute(0, 2, 3, 1)  # -> (B, w, h, D)
+        patch_embed = patch_embed.reshape(B, -1, D)  # -> (B, w*h, D)
 
-        x = torch.cat((cls_embed, patch_embed), dim=1)  # -> (1, w*h+1, D)
+        x = torch.cat((cls_embed, patch_embed), dim=1)  # -> (B, w*h+1, D)
         return x
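
For reference, below is a minimal standalone sketch of the batched interpolation this patch introduces. It uses torch.nn.functional.interpolate as a stand-in for the refiners interpolate helper seen in the diff; the function name interpolate_embedding and the DINOv2 ViT-L/14 shapes in the usage lines are illustrative assumptions, not part of the patch.

from math import sqrt

import torch
from torch.nn.functional import interpolate


def interpolate_embedding(
    x: torch.Tensor,      # (B, N+1, D): class token followed by N patch tokens
    input: torch.Tensor,  # (B, C, W, H): image batch, matching the diff's axis order
    patch_size: int,
    mode: str = "bicubic",
    antialias: bool = False,
) -> torch.Tensor:
    cls_embed = x[:, :1, :]    # -> (B, 1, D)
    patch_embed = x[:, 1:, :]  # -> (B, N, D)

    B, N, D = patch_embed.shape
    M = int(sqrt(N))
    assert M * M == N, "The sequence length must be a square number."

    # target grid size in patches
    w = input.shape[2] // patch_size
    h = input.shape[3] // patch_size

    # lay the tokens out as a 2D grid, channels-first for interpolate
    grid = patch_embed.reshape(B, M, M, D).permute(0, 3, 1, 2)  # -> (B, D, M, M)
    grid = interpolate(
        grid.to(dtype=torch.float32),
        size=(w, h),
        mode=mode,
        antialias=antialias,
    ).to(dtype=cls_embed.dtype)                                 # -> (B, D, w, h)
    patch_embed = grid.permute(0, 2, 3, 1).reshape(B, -1, D)    # -> (B, w*h, D)

    # the class token passes through untouched
    return torch.cat((cls_embed, patch_embed), dim=1)           # -> (B, w*h+1, D)


# Usage: a ViT-L/14 positional embedding (trained at 518x518, so M = 37)
# resized for a batch of two 392x518 inputs; the key point is that B = 2
# flows through where the old code hardcoded 1.
pos = torch.randn(2, 1 + 37 * 37, 1024)
img = torch.randn(2, 3, 392, 518)
out = interpolate_embedding(pos, img, patch_size=14)  # -> (2, 28 * 37 + 1, 1024)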