refiners/tests/adapters/test_ip_adapter.py
2024-02-01 16:17:07 +01:00

38 lines
982 B
Python

from typing import overload
import pytest
import torch
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
@overload
def new_adapter(target: SD1UNet) -> SD1IPAdapter: ...


@overload
def new_adapter(target: SDXLUNet) -> SDXLIPAdapter: ...


def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
    """Create the IP-Adapter that corresponds to the given UNet.

    Args:
        target: The UNet instance the adapter is built around.

    Returns:
        An ``SD1IPAdapter`` when ``target`` is an ``SD1UNet`` instance,
        otherwise an ``SDXLIPAdapter``. In both cases the adapter is
        constructed with ``target=target``.
    """
    # Anything that is not an SD1 UNet is handled by the SDXL adapter.
    if not isinstance(target, SD1UNet):
        return SDXLIPAdapter(target=target)
    return SD1IPAdapter(target=target)
@no_grad()
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
    """Check that inject() changes the UNet's repr and eject() restores it.

    Args:
        k_unet: UNet class under test (parametrized over SD1 and SDXL).
        test_device: Device passed to the UNet constructor.
    """
    model = k_unet(in_channels=4, device=test_device)
    baseline = repr(model)

    ip_adapter = new_adapter(model)
    # Constructing the adapter on its own must not change the UNet's repr.
    after_construction = repr(model)
    assert after_construction == baseline

    ip_adapter.inject()
    # Once injected, the UNet's repr is expected to differ from the baseline.
    after_injection = repr(model)
    assert after_injection != baseline

    ip_adapter.eject()
    # Ejecting must bring the repr back to exactly the baseline.
    after_ejection = repr(model)
    assert after_ejection == baseline