add tests based on repr for inject / eject

This commit is contained in:
Pierre Chapuis 2024-01-30 14:16:45 +01:00 committed by Cédric Deltheil
parent 0e77ef1720
commit d185711bc5
3 changed files with 48 additions and 1 deletions

View file

@ -11,6 +11,8 @@ from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@pytest.mark.parametrize("k_encoder", [CLIPTextEncoderL])
def test_inject_eject(k_encoder: type[CLIPTextEncoder], test_device: torch.device):
encoder = k_encoder(device=test_device)
initial_repr = repr(encoder)
extender = ConceptExtender(encoder)
cat_embedding = torch.randn((encoder.embedding_dim,), device=test_device)
@ -18,7 +20,9 @@ def test_inject_eject(k_encoder: type[CLIPTextEncoder], test_device: torch.devic
extender_2 = ConceptExtender(encoder)
assert repr(encoder) == initial_repr
extender.inject()
assert repr(encoder) != initial_repr
with pytest.raises(AssertionError) as no_nesting:
extender_2.inject()
@ -34,6 +38,7 @@ def test_inject_eject(k_encoder: type[CLIPTextEncoder], test_device: torch.devic
extender_2.inject().eject()
ConceptExtender(encoder) # no exception
assert repr(encoder) == initial_repr
tokenizer = encoder.ensure_find(CLIPTokenizer)
assert len(tokenizer.encode("<token1>")) > 3

View file

@ -0,0 +1,37 @@
from typing import overload
import pytest
import torch
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
@overload
def new_adapter(target: SD1UNet) -> SD1IPAdapter:
...
@overload
def new_adapter(target: SDXLUNet) -> SDXLIPAdapter:
...
def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
if isinstance(target, SD1UNet):
return SD1IPAdapter(target=target)
else:
return SDXLIPAdapter(target=target)
@no_grad()
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
unet = k_unet(in_channels=4, device=test_device)
initial_repr = repr(unet)
adapter = new_adapter(unet)
assert repr(unet) == initial_repr
adapter.inject()
assert repr(unet) != initial_repr
adapter.eject()
assert repr(unet) == initial_repr

View file

@ -29,7 +29,11 @@ def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
unet = k_unet(in_channels=4, device=test_device)
adapter_1 = new_adapter(unet, "adapter 1").inject()
initial_repr = repr(unet)
adapter_1 = new_adapter(unet, "adapter 1")
assert repr(unet) == initial_repr
adapter_1.inject()
assert repr(unet) != initial_repr
with pytest.raises(AssertionError) as already_injected_error:
new_adapter(unet, "adapter 1").inject()
@ -50,3 +54,4 @@ def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch
assert unet.parent is None
assert unet.find(T2IFeatures) is None
assert repr(unet) == initial_repr