mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix dinov2 interpolation, support batching
This commit is contained in:
parent
ef427538a6
commit
5f07fa9c21
|
@ -73,33 +73,31 @@ class InterpolateEmbedding(fl.Module):
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
cls_embed = x[:, :1, :] # -> (1, 1, D)
|
cls_embed = x[:, :1, :] # -> (B, 1, D)
|
||||||
patch_embed = x[:, 1:, :] # -> (1, N, D)
|
patch_embed = x[:, 1:, :] # -> (B, N, D)
|
||||||
|
|
||||||
|
B = patch_embed.shape[0]
|
||||||
N = patch_embed.shape[1]
|
N = patch_embed.shape[1]
|
||||||
D = patch_embed.shape[2]
|
D = patch_embed.shape[2]
|
||||||
M = int(sqrt(N))
|
M = int(sqrt(N))
|
||||||
W = input.shape[2]
|
W = input.shape[2]
|
||||||
H = input.shape[3]
|
H = input.shape[3]
|
||||||
|
w = W // self.patch_size
|
||||||
|
h = H // self.patch_size
|
||||||
assert M * M == N, "The sequence length must be a square number."
|
assert M * M == N, "The sequence length must be a square number."
|
||||||
|
|
||||||
patch_embed = patch_embed.reshape(1, M, M, D) # -> (1, M, M, D)
|
patch_embed = patch_embed.reshape(B, M, M, D) # -> (B, M, M, D)
|
||||||
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (1, D, M, M)
|
patch_embed = patch_embed.permute(0, 3, 1, 2) # -> (B, D, M, M)
|
||||||
patch_embed = interpolate(
|
patch_embed = interpolate(
|
||||||
x=patch_embed.to(dtype=torch.float32),
|
x=patch_embed.to(dtype=torch.float32),
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
antialias=self.antialias,
|
antialias=self.antialias,
|
||||||
size=torch.Size(
|
size=torch.Size((w, h)),
|
||||||
(
|
).to(dtype=cls_embed.dtype) # -> (B, D, w, h)
|
||||||
W // self.patch_size,
|
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (B, w, h, D)
|
||||||
H // self.patch_size,
|
patch_embed = patch_embed.reshape(B, -1, D) # -> (B, w*h, D)
|
||||||
)
|
|
||||||
),
|
|
||||||
).to(dtype=cls_embed.dtype) # -> (1, D, w, h)
|
|
||||||
patch_embed = patch_embed.permute(0, 2, 3, 1) # -> (1, w, h, D)
|
|
||||||
patch_embed = patch_embed.reshape(1, -1, D) # -> (1, w*h, D)
|
|
||||||
|
|
||||||
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (1, w*h+1, D)
|
x = torch.cat((cls_embed, patch_embed), dim=1) # -> (B, w*h+1, D)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue