mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
skip dinov2 float16 test on cpu + test dinov2 when batch_size>1
This commit is contained in:
parent
5f07fa9c21
commit
2ecf7e4b8c
|
@ -131,6 +131,10 @@ def test_dinov2_float16(
|
|||
resolution: int,
|
||||
test_device: torch.device,
|
||||
) -> None:
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
||||
|
||||
manual_seed(2)
|
||||
|
@ -144,3 +148,22 @@ def test_dinov2_float16(
|
|||
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||
assert output.shape == (1, sequence_length, model.embedding_dim)
|
||||
assert output.dtype == torch.float16
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_dinov2_batch_size(
|
||||
resolution: int,
|
||||
test_device: torch.device,
|
||||
) -> None:
|
||||
model = DINOv2_small(device=test_device)
|
||||
|
||||
batch_size = 4
|
||||
manual_seed(2)
|
||||
input_data = torch.randn(
|
||||
(batch_size, 3, resolution, resolution),
|
||||
device=test_device,
|
||||
)
|
||||
|
||||
output = model(input_data)
|
||||
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||
assert output.shape == (batch_size, sequence_length, model.embedding_dim)
|
||||
|
|
Loading…
Reference in a new issue