skip dinov2 float16 test on cpu + test dinov2 when batch_size>1

Laurent 2024-04-02 16:39:28 +00:00 committed by Laureηt
parent 5f07fa9c21
commit 2ecf7e4b8c


@@ -131,6 +131,10 @@ def test_dinov2_float16(
     resolution: int,
     test_device: torch.device,
 ) -> None:
+    if test_device.type == "cpu":
+        warn("not running test_dinov2_float16 on CPU, skipping")
+        pytest.skip()
     model = DINOv2_small(device=test_device, dtype=torch.float16)
     manual_seed(2)
@@ -144,3 +148,22 @@ def test_dinov2_float16(
     sequence_length = (resolution // model.patch_size) ** 2 + 1
     assert output.shape == (1, sequence_length, model.embedding_dim)
     assert output.dtype == torch.float16
+
+
+@no_grad()
+def test_dinov2_batch_size(
+    resolution: int,
+    test_device: torch.device,
+) -> None:
+    model = DINOv2_small(device=test_device)
+    batch_size = 4
+    manual_seed(2)
+    input_data = torch.randn(
+        (batch_size, 3, resolution, resolution),
+        device=test_device,
+    )
+    output = model(input_data)
+    sequence_length = (resolution // model.patch_size) ** 2 + 1
+    assert output.shape == (batch_size, sequence_length, model.embedding_dim)
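
Both tests depend on resolution and test_device fixtures defined outside this diff. As a minimal sketch, assuming the usual 518px DINOv2 input resolution and a CUDA-when-available device (neither detail is confirmed by this commit), a conftest.py providing them could look like:

# conftest.py (hypothetical) -- fixtures assumed by the tests above.
import pytest
import torch


@pytest.fixture(params=[518])  # 518 = 37 * 14, a multiple of the patch size
def resolution(request: pytest.FixtureRequest) -> int:
    return request.param


@pytest.fixture
def test_device() -> torch.device:
    # Prefer the GPU so test_dinov2_float16 is not skipped.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

With such fixtures in place, the new test can be run in isolation via pytest -k test_dinov2_batch_size.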