test_converter: use proper exception type

Follow up of #102
This commit is contained in:
Cédric Deltheil 2023-10-18 12:33:40 +02:00 committed by Cédric Deltheil
parent 6ddd901767
commit 46dd710076

View file

@ -3,7 +3,7 @@ import pytest
from warnings import warn
import refiners.fluxion.layers as fl
from refiners.fluxion.layers.chain import Distribute
from refiners.fluxion.layers.chain import ChainError, Distribute
def test_converter_device_single_tensor(test_device: torch.device) -> None:
@ -66,5 +66,5 @@ def test_converter_no_parent_device_or_dtype() -> None:
tensor = torch.randn(1, 1)
with pytest.raises(expected_exception=ValueError):
with pytest.raises(expected_exception=ChainError):
chain(tensor)