mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 05:38:46 +00:00
add various torch.dtype test fixtures
This commit is contained in:
parent
16714e6745
commit
b20474f8f5
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from pytest import fixture
|
||||
from pytest import FixtureRequest, fixture, skip
|
||||
|
||||
from refiners.fluxion.utils import str_to_dtype
|
||||
|
||||
PARENT_PATH = Path(__file__).parent
|
||||
|
||||
|
@ -18,6 +21,23 @@ def test_device() -> torch.device:
|
|||
return torch.device(test_device)
|
||||
|
||||
|
||||
def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
|
||||
@fixture(scope="session", params=params)
|
||||
def dtype_fixture(request: FixtureRequest) -> torch.dtype:
|
||||
torch_dtype = str_to_dtype(request.param)
|
||||
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
|
||||
skip("bfloat16 is not supported on this test device")
|
||||
return torch_dtype
|
||||
|
||||
return dtype_fixture
|
||||
|
||||
|
||||
test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
|
||||
test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
|
||||
test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
|
||||
test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
def test_weights_path() -> Path:
|
||||
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
|
||||
|
|
Loading…
Reference in a new issue