diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index d59c627..0e1a6c5 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -1,3 +1,4 @@ +import gc from pathlib import Path from typing import Iterator from warnings import warn @@ -30,6 +31,13 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import St from tests.utils import ensure_similar_images +@pytest.fixture(autouse=True) +def ensure_gc(): + # Avoid GPU OOMs + # See https://github.com/pytest-dev/pytest/discussions/8153#discussioncomment-214812 + gc.collect() + + @pytest.fixture(scope="module") def ref_path(test_e2e_path: Path) -> Path: return test_e2e_path / "test_diffusion_ref"