add various torch.dtype test fixtures

This commit is contained in:
Laurent 2024-10-03 08:46:46 +00:00 committed by Laureηt
parent 16714e6745
commit b20474f8f5

View file

@ -1,8 +1,11 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Callable
import torch import torch
from pytest import fixture from pytest import FixtureRequest, fixture, skip
from refiners.fluxion.utils import str_to_dtype
PARENT_PATH = Path(__file__).parent PARENT_PATH = Path(__file__).parent
@ -18,6 +21,23 @@ def test_device() -> torch.device:
return torch.device(test_device) return torch.device(test_device)
def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
@fixture(scope="session", params=params)
def dtype_fixture(request: FixtureRequest) -> torch.dtype:
torch_dtype = str_to_dtype(request.param)
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
skip("bfloat16 is not supported on this test device")
return torch_dtype
return dtype_fixture
test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
@fixture(scope="session") @fixture(scope="session")
def test_weights_path() -> Path: def test_weights_path() -> Path:
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR") from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")