properly check for bfloat16

- we check only the test device, not the machine in general
- we don't want emulated bfloat16 (e.g. CPU)
This commit is contained in:
Pierre Chapuis 2024-10-09 10:49:11 +02:00
parent f3d2b6c325
commit 2796117d2d
2 changed files with 14 additions and 4 deletions

View file

@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
return torch.bool return torch.bool
case _: case _:
raise ValueError(f"Unknown dtype: {dtype}") raise ValueError(f"Unknown dtype: {dtype}")
def device_has_bfloat16(device: torch.device) -> bool:
cuda_version = cast(str | None, torch.version.cuda) # type: ignore
if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
return False
try:
return torch.cuda.get_device_properties(device).major >= 8 # type: ignore
except ValueError:
return False

View file

@ -5,7 +5,7 @@ from typing import Callable
import torch import torch
from pytest import FixtureRequest, fixture, skip from pytest import FixtureRequest, fixture, skip
from refiners.fluxion.utils import str_to_dtype from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype
PARENT_PATH = Path(__file__).parent PARENT_PATH = Path(__file__).parent
@ -21,11 +21,11 @@ def test_device() -> torch.device:
return torch.device(test_device) return torch.device(test_device)
def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]: def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
@fixture(scope="session", params=params) @fixture(scope="session", params=params)
def dtype_fixture(request: FixtureRequest) -> torch.dtype: def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
torch_dtype = str_to_dtype(request.param) torch_dtype = str_to_dtype(request.param)
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
skip("bfloat16 is not supported on this test device") skip("bfloat16 is not supported on this test device")
return torch_dtype return torch_dtype