mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
do not hardcode a CUDA device in tests
This commit is contained in:
parent
c55917e293
commit
d54a38ae07
|
@ -1,20 +1,25 @@
|
|||
import torch
|
||||
import pytest
|
||||
from warnings import warn
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.layers.chain import Distribute
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
|
||||
def test_converter_device_single_tensor() -> None:
|
||||
def test_converter_device_single_tensor(test_device: torch.device) -> None:
|
||||
if test_device.type != "cuda":
|
||||
warn("only running on CUDA, skipping")
|
||||
pytest.skip()
|
||||
|
||||
chain = fl.Chain(
|
||||
fl.Converter(set_device=True, set_dtype=False),
|
||||
fl.Linear(in_features=1, out_features=1, device="cuda:0"),
|
||||
fl.Linear(in_features=1, out_features=1, device=test_device),
|
||||
)
|
||||
|
||||
tensor = torch.randn(1, 1)
|
||||
converted_tensor = chain(tensor)
|
||||
|
||||
assert converted_tensor.device == torch.device(device="cuda:0")
|
||||
assert converted_tensor.device == torch.device(device=test_device)
|
||||
|
||||
|
||||
def test_converter_dtype_single_tensor() -> None:
|
||||
|
@ -29,13 +34,16 @@ def test_converter_dtype_single_tensor() -> None:
|
|||
assert converted_tensor.dtype == torch.float64
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
|
||||
def test_converter_multiple_tensors() -> None:
|
||||
def test_converter_multiple_tensors(test_device: torch.device) -> None:
|
||||
if test_device.type != "cuda":
|
||||
warn("only running on CUDA, skipping")
|
||||
pytest.skip()
|
||||
|
||||
chain = fl.Chain(
|
||||
fl.Converter(set_device=True, set_dtype=True),
|
||||
Distribute(
|
||||
fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64),
|
||||
fl.Linear(in_features=1, out_features=1, device="cuda:0", dtype=torch.float64),
|
||||
fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
|
||||
fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -44,9 +52,9 @@ def test_converter_multiple_tensors() -> None:
|
|||
|
||||
converted_tensor1, converted_tensor2 = chain(tensor1, tensor2)
|
||||
|
||||
assert converted_tensor1.device == torch.device(device="cuda:0")
|
||||
assert converted_tensor1.device == torch.device(device=test_device)
|
||||
assert converted_tensor1.dtype == torch.float64
|
||||
assert converted_tensor2.device == torch.device(device="cuda:0")
|
||||
assert converted_tensor2.device == torch.device(device=test_device)
|
||||
assert converted_tensor2.dtype == torch.float64
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue