fix dinov2 interpolation, support batching

This commit is contained in:
Laurent 2024-04-02 16:39:05 +00:00 committed by Laureηt
parent ef427538a6
commit 5f07fa9c21

View file

@ -73,33 +73,31 @@ class InterpolateEmbedding(fl.Module):
x: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
cls_embed = x[:, :1, :] # -> (1, 1, D)
patch_embed = x[:, 1:, :] # -> (1, N, D)
cls_embed = x[:, :1, :] # -> (B, 1, D)
patch_embed = x[:, 1:, :] # -> (B, N, D)
B = patch_embed.shape[0]
N = patch_embed.shape[1]
D = patch_embed.shape[2]
M = int(sqrt(N))
W = input.shape[2]
H = input.shape[3]
w = W // self.patch_size
h = H // self.patch_size
assert M * M == N, "The sequence length must be a square number."
patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D)
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M)
patch_embed = patch_embed.reshape(B, M, M, D) # -> (B, M, M, D)
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (B, D, M, M)
patch_embed = interpolate(
x=patch_embed.to(dtype=torch.float32),
mode=self.mode,
antialias=self.antialias,
size=torch.Size(
(
W // self.patch_size,
H // self.patch_size,
)
),
).to(dtype=cls_embed.dtype) # -> (1, D, w, h)
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D)
patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D)
size=torch.Size((w, h)),
).to(dtype=cls_embed.dtype) # -> (B, D, w, h)
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (B, w, h, D)
patch_embed = patch_embed.reshape(B, -1, D) # -> (B, w*h, D)
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D)
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (B, w*h+1, D)
return x