From 2ecf7e4b8cd2ae5d547d6fa6d9328e448f280d45 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Tue, 2 Apr 2024 16:39:28 +0000
Subject: [PATCH] skip dinov2 float16 test on cpu + test dinov2 when
 batch_size>1

---
 tests/foundationals/dinov2/test_dinov2.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py
index 9f68d15..c3cf298 100644
--- a/tests/foundationals/dinov2/test_dinov2.py
+++ b/tests/foundationals/dinov2/test_dinov2.py
@@ -131,6 +131,10 @@ def test_dinov2_float16(
     resolution: int,
     test_device: torch.device,
 ) -> None:
+    if test_device.type == "cpu":
+        warn("not running on CPU, skipping")
+        pytest.skip()
+
     model = DINOv2_small(device=test_device, dtype=torch.float16)
 
     manual_seed(2)
@@ -144,3 +148,22 @@
     sequence_length = (resolution // model.patch_size) ** 2 + 1
     assert output.shape == (1, sequence_length, model.embedding_dim)
     assert output.dtype == torch.float16
+
+
+@no_grad()
+def test_dinov2_batch_size(
+    resolution: int,
+    test_device: torch.device,
+) -> None:
+    model = DINOv2_small(device=test_device)
+
+    batch_size = 4
+    manual_seed(2)
+    input_data = torch.randn(
+        (batch_size, 3, resolution, resolution),
+        device=test_device,
+    )
+
+    output = model(input_data)
+    sequence_length = (resolution // model.patch_size) ** 2 + 1
+    assert output.shape == (batch_size, sequence_length, model.embedding_dim)
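
Note: both tests resolve resolution and test_device through pytest fixtures defined
elsewhere in the test suite. Below is a minimal sketch of what a compatible
conftest.py could provide, assuming a 518-pixel input (a multiple of DINOv2's
14-pixel patch size) and a CUDA-else-CPU device choice; the repository's actual
fixtures may differ.

# conftest.py -- hypothetical sketch, not the repository's actual fixtures
import pytest
import torch


@pytest.fixture
def resolution() -> int:
    # Assumed value: the tests use floor division by model.patch_size,
    # so any size works; 518 is an exact multiple of DINOv2's 14-pixel patches.
    return 518


@pytest.fixture
def test_device() -> torch.device:
    # Prefer CUDA when available so that test_dinov2_float16 is not skipped.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

With these assumed values, the expected sequence length checked by the
assertions is (518 // 14) ** 2 + 1 = 1370 tokens per image: one per 14x14
patch plus the class token.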