diff --git a/tests/fluxion/layers/test_converter.py b/tests/fluxion/layers/test_converter.py index 89350e2..611d4fa 100644 --- a/tests/fluxion/layers/test_converter.py +++ b/tests/fluxion/layers/test_converter.py @@ -3,7 +3,7 @@ import pytest from warnings import warn import refiners.fluxion.layers as fl -from refiners.fluxion.layers.chain import Distribute +from refiners.fluxion.layers.chain import ChainError, Distribute def test_converter_device_single_tensor(test_device: torch.device) -> None: @@ -66,5 +66,5 @@ def test_converter_no_parent_device_or_dtype() -> None: tensor = torch.randn(1, 1) - with pytest.raises(expected_exception=ValueError): + with pytest.raises(expected_exception=ChainError): chain(tensor)