diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py index d994d57..47f7959 100644 --- a/tests/foundationals/dinov2/test_dinov2.py +++ b/tests/foundationals/dinov2/test_dinov2.py @@ -60,13 +60,14 @@ def ref_model( if "reg" not in flavor: kwargs["interpolate_offset"] = 0.0 - model = torch.hub.load( # type: ignore + model: torch.nn.Module = torch.hub.load( # type: ignore model=flavor, repo_or_dir=str(dinov2_repo_path), source="local", pretrained=False, # to turn off automatic weights download (see load_state_dict below) **kwargs, - ).to(device=test_device) + ) + model = model.to(device=test_device) flavor = flavor.replace("_reg", "_reg4") weights = test_weights_path / f"{flavor}_pretrain.pth"