refiners/tests/fluxion/layers/test_converter.py

from typing import Any, Callable
from warnings import warn

import pytest
import torch

import refiners.fluxion.layers as fl
from refiners.fluxion.layers.chain import ChainError, Distribute


def test_converter_device_single_tensor(test_device: torch.device) -> None:
    if test_device.type != "cuda":
        warn("only running on CUDA, skipping")
        pytest.skip()

    # The Converter should move the input tensor to the device of its sibling Linear.
    chain = fl.Chain(
        fl.Converter(set_device=True, set_dtype=False),
        fl.Linear(in_features=1, out_features=1, device=test_device),
    )

    tensor = torch.randn(1, 1)
    converted_tensor = chain(tensor)

    assert converted_tensor.device == torch.device(device=test_device)


def test_converter_dtype_single_tensor() -> None:
    # The Converter should cast the input tensor to the dtype of its sibling Linear.
    chain = fl.Chain(
        fl.Converter(set_device=False, set_dtype=True),
        fl.Linear(in_features=1, out_features=1, dtype=torch.float64),
    )

    tensor = torch.randn(1, 1).to(dtype=torch.float32)
    converted_tensor = chain(tensor)

    assert converted_tensor.dtype == torch.float64


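# Hypothetical combined case (a sketch): assuming the Converter can set device and dtype
# at the same time, as the two flags and the multi-tensor test below suggest, a single
# tensor should come out on test_device as float64. The test name is illustrative.
def test_converter_device_and_dtype_single_tensor(test_device: torch.device) -> None:
    if test_device.type != "cuda":
        warn("only running on CUDA, skipping")
        pytest.skip()

    chain = fl.Chain(
        fl.Converter(set_device=True, set_dtype=True),
        fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
    )

    tensor = torch.randn(1, 1)
    converted_tensor = chain(tensor)

    assert converted_tensor.device == torch.device(device=test_device)
    assert converted_tensor.dtype == torch.float64

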
def test_converter_multiple_tensors(test_device: torch.device) -> None:
    if test_device.type != "cuda":
        warn("only running on CUDA, skipping")
        pytest.skip()

    # With Distribute, each input tensor is dispatched to its own branch; the Converter
    # should move and cast every input before the dispatch happens.
    chain = fl.Chain(
        fl.Converter(set_device=True, set_dtype=True),
        Distribute(
            fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
            fl.Linear(in_features=1, out_features=1, device=test_device, dtype=torch.float64),
        ),
    )

    tensor1 = torch.randn(1, 1)
    tensor2 = torch.randn(1, 1)
    converted_tensor1, converted_tensor2 = chain(tensor1, tensor2)

    assert converted_tensor1.device == torch.device(device=test_device)
    assert converted_tensor1.dtype == torch.float64
    assert converted_tensor2.device == torch.device(device=test_device)
    assert converted_tensor2.dtype == torch.float64


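# Hypothetical device-only variant (a sketch): assuming set_dtype=False leaves the input
# dtypes untouched in the multi-tensor path as well, both outputs should land on
# test_device but stay float32. The test name is illustrative.
def test_converter_device_only_multiple_tensors(test_device: torch.device) -> None:
    if test_device.type != "cuda":
        warn("only running on CUDA, skipping")
        pytest.skip()

    chain = fl.Chain(
        fl.Converter(set_device=True, set_dtype=False),
        Distribute(
            fl.Linear(in_features=1, out_features=1, device=test_device),
            fl.Linear(in_features=1, out_features=1, device=test_device),
        ),
    )

    tensor1 = torch.randn(1, 1)
    tensor2 = torch.randn(1, 1)
    converted_tensor1, converted_tensor2 = chain(tensor1, tensor2)

    assert converted_tensor1.device == torch.device(device=test_device)
    assert converted_tensor1.dtype == torch.float32
    assert converted_tensor2.device == torch.device(device=test_device)
    assert converted_tensor2.dtype == torch.float32

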
def test_converter_no_parent_device_or_dtype() -> None:
    identity: Callable[[Any], Any] = lambda x: x

    # Neither the Lambda nor the Converter carries parameters, so the parent Chain
    # exposes no device or dtype for the Converter to infer; the call is expected to
    # fail with a ChainError.
    chain = fl.Chain(
        fl.Lambda(func=identity),
        fl.Converter(set_device=True, set_dtype=False),
    )

    tensor = torch.randn(1, 1)
    with pytest.raises(expected_exception=ChainError):
        chain(tensor)
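

# Hypothetical companion check (a sketch): assuming the failure above comes from the
# parent Chain having no parameters from which a device or dtype could be inferred,
# giving the Converter a weighted sibling should make an otherwise similar chain
# succeed. The test name is illustrative.
def test_converter_with_weighted_sibling() -> None:
    identity: Callable[[Any], Any] = lambda x: x

    chain = fl.Chain(
        fl.Lambda(func=identity),
        fl.Converter(set_device=False, set_dtype=True),
        fl.Linear(in_features=1, out_features=1, dtype=torch.float64),
    )

    tensor = torch.randn(1, 1)
    converted_tensor = chain(tensor)

    assert converted_tensor.dtype == torch.float64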