add various torch.dtype test fixtures

This commit is contained in:
Laurent 2024-10-03 08:46:46 +00:00 committed by Laureηt
parent 16714e6745
commit b20474f8f5

View file

@@ -1,8 +1,11 @@
import os
from pathlib import Path
from typing import Callable
import torch
from pytest import fixture
from pytest import FixtureRequest, fixture, skip
from refiners.fluxion.utils import str_to_dtype
PARENT_PATH = Path(__file__).parent
@@ -18,6 +21,23 @@ def test_device() -> torch.device:
return torch.device(test_device)
def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
    """Build a session-scoped pytest fixture parametrized over dtype names.

    Each name in ``params`` is resolved to a ``torch.dtype`` via
    ``str_to_dtype``; a ``bfloat16`` parameter is skipped when the CUDA
    backend reports no bf16 support.
    """

    @fixture(scope="session", params=params)
    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
        resolved_dtype = str_to_dtype(request.param)
        # Short-circuit: only probe CUDA bf16 support when bf16 is requested.
        if resolved_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
            skip("bfloat16 is not supported on this test device")
        return resolved_dtype

    return dtype_fixture
# Session-scoped dtype fixtures: request one of these by name to run a test
# once per listed torch dtype. bfloat16 variants are skipped at collection
# time on devices without bf16 support (see dtype_fixture_factory).
test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
@fixture(scope="session")
def test_weights_path() -> Path:
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")