use float16, save memory

Pierre Chapuis 2024-01-30 15:05:06 +01:00 committed by Cédric Deltheil
parent d185711bc5
commit 8341d3a74b
2 changed files with 2 additions and 2 deletions

First changed file (test for the IP adapter):

@@ -27,7 +27,7 @@ def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
 @no_grad()
 @pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
 def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
-    unet = k_unet(in_channels=4, device=test_device)
+    unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
     initial_repr = repr(unet)
     adapter = new_adapter(unet)
     assert repr(unet) == initial_repr

Second changed file (test for the T2I adapter):

@@ -28,7 +28,7 @@ def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2
 @no_grad()
 @pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
 def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
-    unet = k_unet(in_channels=4, device=test_device)
+    unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
     initial_repr = repr(unet)
     adapter_1 = new_adapter(unet, "adapter 1")
     assert repr(unet) == initial_repr
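
For context, a minimal sketch of why building the model in float16 saves memory. This is not part of the commit: it uses a plain torch.nn.Linear as a stand-in for the SD1UNet / SDXLUNet constructed in the tests, and simply compares parameter footprints between the default float32 and float16.

import torch

def param_bytes(module: torch.nn.Module) -> int:
    # Total bytes occupied by the module's parameters.
    return sum(p.numel() * p.element_size() for p in module.parameters())

# Stand-in module; the real tests instantiate the UNet with dtype=torch.float16.
fp32_layer = torch.nn.Linear(4096, 4096, dtype=torch.float32)
fp16_layer = torch.nn.Linear(4096, 4096, dtype=torch.float16)

print(param_bytes(fp32_layer))  # ~67 MB: 4 bytes per float32 parameter
print(param_bytes(fp16_layer))  # ~34 MB: 2 bytes per float16 parameter

Since each float16 parameter takes 2 bytes instead of 4, the same architecture occupies roughly half the memory, which is the point of passing dtype=torch.float16 in the tests above.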