From b20474f8f5a4fee451500dc6f05ed11c01b6165a Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 3 Oct 2024 08:46:46 +0000
Subject: [PATCH] add various torch.dtype test fixtures

---
 tests/conftest.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index ba56acb..ef5bad8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,11 @@
 import os
 from pathlib import Path
+from typing import Callable
 
 import torch
-from pytest import fixture
+from pytest import FixtureRequest, fixture, skip
+
+from refiners.fluxion.utils import str_to_dtype
 
 PARENT_PATH = Path(__file__).parent
 
@@ -18,6 +21,23 @@ def test_device() -> torch.device:
     return torch.device(test_device)
 
 
+def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+    @fixture(scope="session", params=params)
+    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+        torch_dtype = str_to_dtype(request.param)
+        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+            skip("bfloat16 is not supported on this test device")
+        return torch_dtype
+
+    return dtype_fixture
+
+
+test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
+test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
+test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
+test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
+
+
 @fixture(scope="session")
 def test_weights_path() -> Path:
     from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
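
Usage note (not part of the patch): a minimal sketch of how a test could consume one of these
session-scoped, parametrized fixtures. It assumes pytest registers each factory-produced fixture
under its conftest attribute name (e.g. test_dtype_fp32_fp16), so a test requesting it runs once
per dtype in the factory's parameter list, with bfloat16 runs skipped on devices lacking support.
The test name and tensor shape below are hypothetical, for illustration only.

    # tests/test_example.py (hypothetical)
    import torch

    def test_tensor_dtype(test_device: torch.device, test_dtype_fp32_fp16: torch.dtype) -> None:
        # Parametrized by the fixture: one run with float32, one with float16.
        x = torch.randn(1, 3, 8, 8, device=test_device, dtype=test_dtype_fp32_fp16)
        assert x.dtype == test_dtype_fp32_fp16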