do not hardcode a CUDA device in tests

This commit is contained in:
Pierre Chapuis 2023-09-06 18:53:16 +02:00
parent c55917e293
commit d54a38ae07

View file

@ -1,20 +1,25 @@
import torch import torch
import pytest import pytest
from warnings import warn
import refiners.fluxion.layers as fl import refiners.fluxion.layers as fl
from refiners.fluxion.layers.chain import Distribute from refiners.fluxion.layers.chain import Distribute
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") def test_converter_device_single_tensor(test_device: torch.device) -> None:
def test_converter_device_single_tensor() -> None: if test_device.type != "cuda":
warn("only running on CUDA, skipping")
pytest.skip()
chain = fl.Chain( chain = fl.Chain(
fl.Converter(set_device=True, set_dtype=False), fl.Converter(set_device=True, set_dtype=False),
fl.Linear(in_features=1, out_features=1, device="cuda:0"), fl.Linear(in_features=1, out_features=1, device=test_device),
) )
tensor = torch.randn(1, 1) tensor = torch.randn(1, 1)
converted_tensor = chain(tensor) converted_tensor = chain(tensor)
assert converted_tensor.device == torch.device(device="cuda:0") assert converted_tensor.device == torch.device(device=test_device)
def test_converter_dtype_single_tensor() -> None: def test_converter_dtype_single_tensor() -> None:
@ -29,13 +34,16 @@ def test_converter_dtype_single_tensor() -> None:
assert converted_tensor.dtype == torch.float64 assert converted_tensor.dtype == torch.float64
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") def test_converter_multiple_tensors(test_device: torch.device) -> None:
def test_converter_multiple_tensors() -> None: if test_device.type != "cuda":
warn("only running on CUDA, skipping")
pytest.skip()
chain = fl.Chain( chain = fl.Chain(
fl.Converter(set_device=True, set_dtype=True), fl.Converter(set_device=True, set_dtype=True),
Distribute( Distribute(
fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64), fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64), fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
), ),
) )
@ -44,9 +52,9 @@ def test_converter_multiple_tensors() -> None:
converted_tensor1, converted_tensor2 = chain(tensor1, tensor2) converted_tensor1, converted_tensor2 = chain(tensor1, tensor2)
assert converted_tensor1.device == torch.device(device="cuda:0") assert converted_tensor1.device == torch.device(device=test_device)
assert converted_tensor1.dtype == torch.float64 assert converted_tensor1.dtype == torch.float64
assert converted_tensor2.device == torch.device(device="cuda:0") assert converted_tensor2.device == torch.device(device=test_device)
assert converted_tensor2.dtype == torch.float64 assert converted_tensor2.dtype == torch.float64