mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
skip dinov2 float16 test on cpu + test dinov2 when batch_size>1
This commit is contained in:
parent
5f07fa9c21
commit
2ecf7e4b8c
|
@ -131,6 +131,10 @@ def test_dinov2_float16(
|
||||||
resolution: int,
|
resolution: int,
|
||||||
test_device: torch.device,
|
test_device: torch.device,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
model = DINOv2_small(device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
manual_seed(2)
|
manual_seed(2)
|
||||||
|
@ -144,3 +148,22 @@ def test_dinov2_float16(
|
||||||
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||||
assert output.shape == (1, sequence_length, model.embedding_dim)
|
assert output.shape == (1, sequence_length, model.embedding_dim)
|
||||||
assert output.dtype == torch.float16
|
assert output.dtype == torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_dinov2_batch_size(
|
||||||
|
resolution: int,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> None:
|
||||||
|
model = DINOv2_small(device=test_device)
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
manual_seed(2)
|
||||||
|
input_data = torch.randn(
|
||||||
|
(batch_size, 3, resolution, resolution),
|
||||||
|
device=test_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = model(input_data)
|
||||||
|
sequence_length = (resolution // model.patch_size) ** 2 + 1
|
||||||
|
assert output.shape == (batch_size, sequence_length, model.embedding_dim)
|
||||||
|
|
Loading…
Reference in a new issue