Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-15 01:28:14 +00:00)
properly check for bfloat16

- we check only the test device, not the machine in general
- we don't want emulated bfloat16 (e.g. CPU)

commit 2796117d2d
parent f3d2b6c325
@@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
             return torch.bool
         case _:
             raise ValueError(f"Unknown dtype: {dtype}")
+
+
+def device_has_bfloat16(device: torch.device) -> bool:
+    cuda_version = cast(str | None, torch.version.cuda)  # type: ignore
+    if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
+        return False
+    try:
+        return torch.cuda.get_device_properties(device).major >= 8  # type: ignore
+    except ValueError:
+        return False
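For context, a minimal usage sketch of the new helper (illustrative, not part of the commit): a non-CUDA device such as the CPU never reports bfloat16 support, because torch.cuda.get_device_properties raises ValueError for it and the helper turns that into False, while a CUDA device reports support only with a CUDA 11+ build and compute capability 8.0 (Ampere) or newer. This is what keeps emulated bfloat16 off the test device.

# Illustrative sketch only -- not part of the diff above.
import torch

from refiners.fluxion.utils import device_has_bfloat16

# Non-CUDA devices: torch.cuda.get_device_properties() raises ValueError,
# which the helper catches and maps to False.
assert device_has_bfloat16(torch.device("cpu")) is False

# CUDA devices: True only with a CUDA 11+ build and a GPU of compute
# capability >= 8.0 (Ampere or newer).
if torch.cuda.is_available():
    print(device_has_bfloat16(torch.device("cuda:0")))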
@@ -5,7 +5,7 @@ from typing import Callable
 import torch
 from pytest import FixtureRequest, fixture, skip
 
-from refiners.fluxion.utils import str_to_dtype
+from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype
 
 PARENT_PATH = Path(__file__).parent
 
@@ -21,11 +21,11 @@ def test_device() -> torch.device:
     return torch.device(test_device)
 
 
-def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
     @fixture(scope="session", params=params)
-    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+    def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
         torch_dtype = str_to_dtype(request.param)
-        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
             skip("bfloat16 is not supported on this test device")
         return torch_dtype
 
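A minimal sketch of how the updated factory might be consumed in the same conftest.py; the fixture name test_dtype_fp32_bf16 and the test below are hypothetical, not taken from the diff. Because the nested dtype_fixture now requests the test_device fixture, pytest injects the configured device, and bfloat16 parametrizations are skipped whenever device_has_bfloat16(test_device) is False.

# Hypothetical conftest.py usage -- the names below are illustrative.
import torch

# Declare a session-scoped fixture parametrized over dtype names;
# pytest registers it under the module-level attribute name.
test_dtype_fp32_bf16 = dtype_fixture_factory(["float32", "bfloat16"])


# A test requesting the fixture runs once per dtype; the bfloat16 run
# is skipped on devices where device_has_bfloat16() returns False.
def test_ones_dtype(test_device: torch.device, test_dtype_fp32_bf16: torch.dtype) -> None:
    x = torch.ones(2, 2, device=test_device, dtype=test_dtype_fp32_bf16)
    assert x.dtype == test_dtype_fp32_bf16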