2024-06-24 14:27:30 +00:00
|
|
|
import gc
|
2024-03-21 13:59:36 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from warnings import warn
|
|
|
|
|
|
|
|
from pytest import fixture, skip
|
|
|
|
|
|
|
|
|
2024-06-24 14:27:30 +00:00
|
|
|
@fixture(autouse=True)
def ensure_gc() -> None:
    """Force a garbage-collection pass as part of every test's setup.

    Meant to avoid GPU OOMs; see
    https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812
    """
    gc.collect()
|
|
|
|
|
|
|
|
|
2024-03-21 13:59:36 +00:00
|
|
|
@fixture(scope="package")
def ref_path(test_sam_path: Path) -> Path:
    """Return the ``test_sam_ref`` entry located under ``test_sam_path``.

    The path is only composed here; nothing checks that it exists.
    """
    return test_sam_path.joinpath("test_sam_ref")
|
|
|
|
|
|
|
|
|
|
|
|
@fixture(scope="package")
def sam_h_weights(test_weights_path: Path) -> Path:
    """Return the path to ``segment-anything-h.safetensors``, or skip if absent.

    Args:
        test_weights_path: Location under which the weights file is looked up.

    Returns:
        Path of the weights file, once it has been confirmed to be a regular
        file.

    When the file is missing, a warning is emitted and the requesting tests
    are skipped through ``pytest.skip``.
    """
    weights_file = test_weights_path / "segment-anything-h.safetensors"
    if not weights_file.is_file():
        reason = f"could not find weights at {weights_file}"
        warn(f"{reason}, skipping")
        # Pass the reason so the skip is self-explanatory in the test report.
        # It is given positionally because the keyword was renamed from `msg`
        # to `reason` across pytest versions. `allow_module_level` is left out:
        # it only matters for calls made at module import time, not in fixtures.
        skip(reason)
    return weights_file
|