diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py index 9f68d15..c3cf298 100644 --- a/tests/foundationals/dinov2/test_dinov2.py +++ b/tests/foundationals/dinov2/test_dinov2.py @@ -131,6 +131,10 @@ def test_dinov2_float16( resolution: int, test_device: torch.device, ) -> None: + if test_device.type == "cpu": + warn("not running on CPU, skipping") + pytest.skip() + model = DINOv2_small(device=test_device, dtype=torch.float16) manual_seed(2) @@ -144,3 +148,22 @@ def test_dinov2_float16( sequence_length = (resolution // model.patch_size) ** 2 + 1 assert output.shape == (1, sequence_length, model.embedding_dim) assert output.dtype == torch.float16 + + +@no_grad() +def test_dinov2_batch_size( + resolution: int, + test_device: torch.device, +) -> None: + model = DINOv2_small(device=test_device) + + batch_size = 4 + manual_seed(2) + input_data = torch.randn( + (batch_size, 3, resolution, resolution), + device=test_device, + ) + + output = model(input_data) + sequence_length = (resolution // model.patch_size) ** 2 + 1 + assert output.shape == (batch_size, sequence_length, model.embedding_dim)