mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add various torch.dtype test fixtures
This commit is contained in:
parent
16714e6745
commit
b20474f8f5
|
@ -1,8 +1,11 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytest import fixture
|
from pytest import FixtureRequest, fixture, skip
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import str_to_dtype
|
||||||
|
|
||||||
PARENT_PATH = Path(__file__).parent
|
PARENT_PATH = Path(__file__).parent
|
||||||
|
|
||||||
|
@ -18,6 +21,23 @@ def test_device() -> torch.device:
|
||||||
return torch.device(test_device)
|
return torch.device(test_device)
|
||||||
|
|
||||||
|
|
||||||
|
def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
|
||||||
|
@fixture(scope="session", params=params)
|
||||||
|
def dtype_fixture(request: FixtureRequest) -> torch.dtype:
|
||||||
|
torch_dtype = str_to_dtype(request.param)
|
||||||
|
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
|
||||||
|
skip("bfloat16 is not supported on this test device")
|
||||||
|
return torch_dtype
|
||||||
|
|
||||||
|
return dtype_fixture
|
||||||
|
|
||||||
|
|
||||||
|
test_dtype_fp32_bf16_fp16 = dtype_fixture_factory(["float32", "bfloat16", "float16"])
|
||||||
|
test_dtype_fp32_fp16 = dtype_fixture_factory(["float32", "float16"])
|
||||||
|
test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])
|
||||||
|
test_dtype_fp16_bf16 = dtype_fixture_factory(["float16", "bfloat16"])
|
||||||
|
|
||||||
|
|
||||||
@fixture(scope="session")
|
@fixture(scope="session")
|
||||||
def test_weights_path() -> Path:
|
def test_weights_path() -> Path:
|
||||||
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
|
from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR")
|
||||||
|
|
Loading…
Reference in a new issue