diff --git a/tests/adapters/test_ip_adapter.py b/tests/adapters/test_ip_adapter.py index 455abd0..0fdc81b 100644 --- a/tests/adapters/test_ip_adapter.py +++ b/tests/adapters/test_ip_adapter.py @@ -27,7 +27,7 @@ def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter: @no_grad() @pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet]) def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device): - unet = k_unet(in_channels=4, device=test_device) + unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16) initial_repr = repr(unet) adapter = new_adapter(unet) assert repr(unet) == initial_repr diff --git a/tests/adapters/test_t2i_adapter.py b/tests/adapters/test_t2i_adapter.py index 5b1dca8..3cb836e 100644 --- a/tests/adapters/test_t2i_adapter.py +++ b/tests/adapters/test_t2i_adapter.py @@ -28,7 +28,7 @@ def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2 @no_grad() @pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet]) def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device): - unet = k_unet(in_channels=4, device=test_device) + unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16) initial_repr = repr(unet) adapter_1 = new_adapter(unet, "adapter 1") assert repr(unet) == initial_repr