diff --git a/tests/fluxion/layers/test_converter.py b/tests/fluxion/layers/test_converter.py index fabe0ac..89350e2 100644 --- a/tests/fluxion/layers/test_converter.py +++ b/tests/fluxion/layers/test_converter.py @@ -1,20 +1,25 @@ import torch import pytest +from warnings import warn + import refiners.fluxion.layers as fl from refiners.fluxion.layers.chain import Distribute -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_converter_device_single_tensor() -> None: +def test_converter_device_single_tensor(test_device: torch.device) -> None: + if test_device.type != "cuda": + warn("only running on CUDA, skipping") + pytest.skip() + chain = fl.Chain( fl.Converter(set_device=True, set_dtype=False), - fl.Linear(in_features=1, out_features=1, device="cuda:0"), + fl.Linear(in_features=1, out_features=1, device=test_device), ) tensor = torch.randn(1, 1) converted_tensor = chain(tensor) - assert converted_tensor.device == torch.device(device="cuda:0") + assert converted_tensor.device == torch.device(device=test_device) def test_converter_dtype_single_tensor() -> None: @@ -29,13 +34,16 @@ def test_converter_dtype_single_tensor() -> None: assert converted_tensor.dtype == torch.float64 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_converter_multiple_tensors() -> None: +def test_converter_multiple_tensors(test_device: torch.device) -> None: + if test_device.type != "cuda": + warn("only running on CUDA, skipping") + pytest.skip() + chain = fl.Chain( fl.Converter(set_device=True, set_dtype=True), Distribute( - fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64), - fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64), + fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64), + fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64), ), ) @@ -44,9 +52,9 @@ def test_converter_multiple_tensors() -> None: converted_tensor1, converted_tensor2 = chain(tensor1, tensor2) - assert converted_tensor1.device == torch.device(device="cuda:0") + assert converted_tensor1.device == torch.device(device=test_device) assert converted_tensor1.dtype == torch.float64 - assert converted_tensor2.device == torch.device(device="cuda:0") + assert converted_tensor2.device == torch.device(device=test_device) assert converted_tensor2.dtype == torch.float64