mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
fix typing issue in dinov2 test
This commit is contained in:
parent
336253f26b
commit
73089a4e2d
|
@ -60,13 +60,14 @@ def ref_model(
|
||||||
if "reg" not in flavor:
|
if "reg" not in flavor:
|
||||||
kwargs["interpolate_offset"] = 0.0
|
kwargs["interpolate_offset"] = 0.0
|
||||||
|
|
||||||
model = torch.hub.load( # type: ignore
|
model: torch.nn.Module = torch.hub.load( # type: ignore
|
||||||
model=flavor,
|
model=flavor,
|
||||||
repo_or_dir=str(dinov2_repo_path),
|
repo_or_dir=str(dinov2_repo_path),
|
||||||
source="local",
|
source="local",
|
||||||
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
|
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
|
||||||
**kwargs,
|
**kwargs,
|
||||||
).to(device=test_device)
|
)
|
||||||
|
model = model.to(device=test_device)
|
||||||
|
|
||||||
flavor = flavor.replace("_reg", "_reg4")
|
flavor = flavor.replace("_reg", "_reg4")
|
||||||
weights = test_weights_path / f"{flavor}_pretrain.pth"
|
weights = test_weights_path / f"{flavor}_pretrain.pth"
|
||||||
|
|
Loading…
Reference in a new issue