mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
use float16, save memory
This commit is contained in:
parent
d185711bc5
commit
8341d3a74b
|
@ -27,7 +27,7 @@ def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
|
|||
@no_grad()
|
||||
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
|
||||
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
|
||||
unet = k_unet(in_channels=4, device=test_device)
|
||||
unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
|
||||
initial_repr = repr(unet)
|
||||
adapter = new_adapter(unet)
|
||||
assert repr(unet) == initial_repr
|
||||
|
|
|
@ -28,7 +28,7 @@ def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2
|
|||
@no_grad()
|
||||
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
|
||||
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
|
||||
unet = k_unet(in_channels=4, device=test_device)
|
||||
unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
|
||||
initial_repr = repr(unet)
|
||||
adapter_1 = new_adapter(unet, "adapter 1")
|
||||
assert repr(unet) == initial_repr
|
||||
|
|
Loading…
Reference in a new issue