mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
Properly check for bfloat16 support:
- we check only the test device, not the machine in general — we don't want emulated bfloat16 (e.g. on CPU)
This commit is contained in:
parent
f3d2b6c325
commit
2796117d2d
|
@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
|
|||
return torch.bool
|
||||
case _:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
|
||||
def device_has_bfloat16(device: torch.device) -> bool:
    """Return True if `device` natively supports bfloat16.

    Only real CUDA bfloat16 counts: emulated bfloat16 (e.g. on CPU) is
    deliberately excluded. Native support requires a CUDA 11+ toolchain
    and a GPU with compute capability >= 8.0 (Ampere or newer).
    """
    # Non-CUDA devices (cpu, mps, meta, ...) never qualify; bail out early
    # so we never touch torch.cuda at all — on CUDA builds running without
    # a GPU driver, get_device_properties raises RuntimeError (not the
    # ValueError caught below) and would crash the caller.
    if device.type != "cuda":
        return False
    cuda_version = cast(str | None, torch.version.cuda)  # type: ignore
    # bfloat16 CUDA kernels require the CUDA 11 toolchain or newer.
    if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
        return False
    try:
        # Compute capability 8.x (Ampere) is the first with native bf16.
        return torch.cuda.get_device_properties(device).major >= 8  # type: ignore
    except ValueError:
        # Raised for an invalid / out-of-range CUDA device index.
        return False
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Callable
|
|||
import torch
|
||||
from pytest import FixtureRequest, fixture, skip
|
||||
|
||||
from refiners.fluxion.utils import str_to_dtype
|
||||
from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype
|
||||
|
||||
PARENT_PATH = Path(__file__).parent
|
||||
|
||||
|
@ -21,11 +21,11 @@ def test_device() -> torch.device:
|
|||
return torch.device(test_device)
|
||||
|
||||
|
||||
def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
|
||||
def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
|
||||
@fixture(scope="session", params=params)
|
||||
def dtype_fixture(request: FixtureRequest) -> torch.dtype:
|
||||
def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
|
||||
torch_dtype = str_to_dtype(request.param)
|
||||
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
|
||||
if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
|
||||
skip("bfloat16 is not supported on this test device")
|
||||
return torch_dtype
|
||||
|
||||
|
|
Loading…
Reference in a new issue