From 8341d3a74ba9384f120d831c42e64caf09398d4f Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 30 Jan 2024 15:05:06 +0100 Subject: [PATCH] use float16, save memory --- tests/adapters/test_ip_adapter.py | 2 +- tests/adapters/test_t2i_adapter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/adapters/test_ip_adapter.py b/tests/adapters/test_ip_adapter.py index 455abd0..0fdc81b 100644 --- a/tests/adapters/test_ip_adapter.py +++ b/tests/adapters/test_ip_adapter.py @@ -27,7 +27,7 @@ def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter: @no_grad() @pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet]) def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device): - unet = k_unet(in_channels=4, device=test_device) + unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16) initial_repr = repr(unet) adapter = new_adapter(unet) assert repr(unet) == initial_repr diff --git a/tests/adapters/test_t2i_adapter.py b/tests/adapters/test_t2i_adapter.py index 5b1dca8..3cb836e 100644 --- a/tests/adapters/test_t2i_adapter.py +++ b/tests/adapters/test_t2i_adapter.py @@ -28,7 +28,7 @@ def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2 @no_grad() @pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet]) def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device): - unet = k_unet(in_channels=4, device=test_device) + unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16) initial_repr = repr(unet) adapter_1 = new_adapter(unet, "adapter 1") assert repr(unet) == initial_repr