From 2796117d2d8322a074b3a17226467893eefef7c9 Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Wed, 9 Oct 2024 10:49:11 +0200
Subject: [PATCH] properly check for bfloat16

- we check only the test device, not the machine in general
- we don't want emulated bfloat16 (e.g. CPU)
---
 src/refiners/fluxion/utils.py | 10 ++++++++++
 tests/conftest.py             |  8 ++++----
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py
index 5de2f99..899702b 100644
--- a/src/refiners/fluxion/utils.py
+++ b/src/refiners/fluxion/utils.py
@@ -304,3 +304,13 @@ def str_to_dtype(dtype: str) -> torch.dtype:
             return torch.bool
         case _:
             raise ValueError(f"Unknown dtype: {dtype}")
+
+
+def device_has_bfloat16(device: torch.device) -> bool:
+    cuda_version = cast(str | None, torch.version.cuda)  # type: ignore
+    if cuda_version is None or int(cuda_version.split(".")[0]) < 11:
+        return False
+    try:
+        return torch.cuda.get_device_properties(device).major >= 8  # type: ignore
+    except ValueError:
+        return False
diff --git a/tests/conftest.py b/tests/conftest.py
index ef5bad8..ee27658 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,7 +5,7 @@ from typing import Callable
 import torch
 from pytest import FixtureRequest, fixture, skip
 
-from refiners.fluxion.utils import str_to_dtype
+from refiners.fluxion.utils import device_has_bfloat16, str_to_dtype
 
 PARENT_PATH = Path(__file__).parent
 
@@ -21,11 +21,11 @@ def test_device() -> torch.device:
     return torch.device(test_device)
 
 
-def dtype_fixture_factory(params: list[str]) -> Callable[[FixtureRequest], torch.dtype]:
+def dtype_fixture_factory(params: list[str]) -> Callable[[torch.device, FixtureRequest], torch.dtype]:
     @fixture(scope="session", params=params)
-    def dtype_fixture(request: FixtureRequest) -> torch.dtype:
+    def dtype_fixture(test_device: torch.device, request: FixtureRequest) -> torch.dtype:
         torch_dtype = str_to_dtype(request.param)
-        if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        if torch_dtype == torch.bfloat16 and not device_has_bfloat16(test_device):
             skip("bfloat16 is not supported on this test device")
         return torch_dtype
 
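
Reviewer note: below is a minimal, illustrative sketch (not part of the patch) of how the new per-device helper differs from the machine-wide torch.cuda.is_bf16_supported() check it replaces. It assumes the patch is applied so device_has_bfloat16 is importable from refiners.fluxion.utils, that PyTorch is a CUDA build, and the device strings are examples only.

    import torch

    # Assumes this patch has been applied to refiners.
    from refiners.fluxion.utils import device_has_bfloat16

    # Old test-suite check: asks whether the machine in general supports bfloat16,
    # which can include emulated bfloat16 rather than reflecting the device the
    # tests actually run on.
    print(torch.cuda.is_bf16_supported())

    # New check: scoped to one device. It returns True only when PyTorch is built
    # against CUDA >= 11 and the GPU has compute capability >= 8.0 (Ampere or
    # newer); a CPU device yields False because the helper catches the ValueError
    # raised by torch.cuda.get_device_properties.
    print(device_has_bfloat16(torch.device("cuda:0")))  # example device string
    print(device_has_bfloat16(torch.device("cpu")))     # False

Catching ValueError in the helper is what lets a non-CUDA test device fall through to False, so the fixture skips the bfloat16 case instead of erroring.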