mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix dinov2 interpolation, support batching
This commit is contained in:
parent
ef427538a6
commit
5f07fa9c21
|
@ -73,33 +73,31 @@ class InterpolateEmbedding(fl.Module):
|
|||
x: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cls_embed = x[:, :1, :] # -> (1, 1, D)
|
||||
patch_embed = x[:, 1:, :] # -> (1, N, D)
|
||||
cls_embed = x[:, :1, :] # -> (B, 1, D)
|
||||
patch_embed = x[:, 1:, :] # -> (B, N, D)
|
||||
|
||||
B = patch_embed.shape[0]
|
||||
N = patch_embed.shape[1]
|
||||
D = patch_embed.shape[2]
|
||||
M = int(sqrt(N))
|
||||
W = input.shape[2]
|
||||
H = input.shape[3]
|
||||
w = W // self.patch_size
|
||||
h = H // self.patch_size
|
||||
assert M * M == N, "The sequence length must be a square number."
|
||||
|
||||
patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D)
|
||||
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M)
|
||||
patch_embed = patch_embed.reshape(B, M, M, D) # -> (B, M, M, D)
|
||||
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (B, D, M, M)
|
||||
patch_embed = interpolate(
|
||||
x=patch_embed.to(dtype=torch.float32),
|
||||
mode=self.mode,
|
||||
antialias=self.antialias,
|
||||
size=torch.Size(
|
||||
(
|
||||
W // self.patch_size,
|
||||
H // self.patch_size,
|
||||
)
|
||||
),
|
||||
).to(dtype=cls_embed.dtype) # -> (1, D, w, h)
|
||||
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D)
|
||||
patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D)
|
||||
size=torch.Size((w, h)),
|
||||
).to(dtype=cls_embed.dtype) # -> (B, D, w, h)
|
||||
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (B, w, h, D)
|
||||
patch_embed = patch_embed.reshape(B, -1, D) # -> (B, w*h, D)
|
||||
|
||||
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D)
|
||||
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (B, w*h+1, D)
|
||||
return x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue