mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
use float16, save memory
This commit is contained in:
parent
d185711bc5
commit
8341d3a74b
|
@ -27,7 +27,7 @@ def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
|
||||||
@no_grad()
|
@no_grad()
|
||||||
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
|
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
|
||||||
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
|
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
|
||||||
unet = k_unet(in_channels=4, device=test_device)
|
unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
|
||||||
initial_repr = repr(unet)
|
initial_repr = repr(unet)
|
||||||
adapter = new_adapter(unet)
|
adapter = new_adapter(unet)
|
||||||
assert repr(unet) == initial_repr
|
assert repr(unet) == initial_repr
|
||||||
|
|
|
@ -28,7 +28,7 @@ def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2
|
||||||
@no_grad()
|
@no_grad()
|
||||||
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
|
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
|
||||||
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
|
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
|
||||||
unet = k_unet(in_channels=4, device=test_device)
|
unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
|
||||||
initial_repr = repr(unet)
|
initial_repr = repr(unet)
|
||||||
adapter_1 = new_adapter(unet, "adapter 1")
|
adapter_1 = new_adapter(unet, "adapter 1")
|
||||||
assert repr(unet) == initial_repr
|
assert repr(unet) == initial_repr
|
||||||
|
|
Loading…
Reference in a new issue