move some tests into the adapters test folder

Laurent 2024-09-08 13:56:36 +00:00
parent a51d695523
commit 60f6f62056
No known key found for this signature in database
6 changed files with 143 additions and 112 deletions

View file

@@ -1,5 +1,3 @@
-from typing import Iterator
-
 import pytest
 import torch
@@ -10,32 +8,46 @@ from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
 from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
 
 
-@pytest.fixture(scope="module", params=[True, False])
-def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
-    xl: bool = request.param
-    unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
-    yield unet
+@pytest.fixture(scope="module")
+def unet(
+    refiners_unet: SD1UNet | SDXLUNet,
+) -> SD1UNet | SDXLUNet:
+    return refiners_unet
 
 
-def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
+def test_inject_eject_freeu(
+    unet: SD1UNet | SDXLUNet,
+) -> None:
+    initial_repr = repr(unet)
     freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
-    assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
-
-    with pytest.raises(AssertionError) as exc:
-        freeu.eject()
-    assert "could not find" in str(exc.value)
+    assert unet.parent is None
+    assert unet.find(FreeUResidualConcatenator) is None
+    assert repr(unet) == initial_repr
 
     freeu.inject()
-    assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
+    assert unet.parent is not None
+    assert unet.find(FreeUResidualConcatenator) is not None
+    assert repr(unet) != initial_repr
 
     freeu.eject()
-    assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
+    assert unet.parent is None
+    assert unet.find(FreeUResidualConcatenator) is None
+    assert repr(unet) == initial_repr
+
+    freeu.inject()
+    assert unet.parent is not None
+    assert unet.find(FreeUResidualConcatenator) is not None
+    assert repr(unet) != initial_repr
+
+    freeu.eject()
+    assert unet.parent is None
+    assert unet.find(FreeUResidualConcatenator) is None
+    assert repr(unet) == initial_repr
 
 
 def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
     num_blocks = len(unet.layer("UpBlocks", Chain))
     with pytest.raises(AssertionError):
         SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
@@ -43,15 +55,16 @@ def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
 def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
     with pytest.raises(AssertionError):
         SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
     with pytest.raises(AssertionError):
         SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2, 1.2], skip_scales=[0.9, 0.9])
 
 
-def test_freeu_identity_scales() -> None:
+def test_freeu_identity_scales(unet: SD1UNet | SDXLUNet) -> None:
     manual_seed(0)
-    text_embedding = torch.randn(1, 77, 768)
-    timestep = torch.randint(0, 999, size=(1, 1))
-    x = torch.randn(1, 4, 32, 32)
+    text_embedding = torch.randn(1, 77, 768, dtype=unet.dtype, device=unet.device)
+    timestep = torch.randint(0, 999, size=(1, 1), device=unet.device)
+    x = torch.randn(1, 4, 32, 32, dtype=unet.dtype, device=unet.device)
 
-    unet = SD1UNet(in_channels=4)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s
 
     with no_grad():
@@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
         unet.set_timestep(timestep=timestep)
         y_2 = unet(x.clone())
 
+    freeu.eject()
+
     # The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
     assert torch.allclose(y_1, y_2, atol=1e-5)
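
Editor's note: for readers unfamiliar with the adapter API exercised above, here is a minimal usage sketch. It is assembled only from calls that appear in this diff (SDFreeUAdapter construction, inject, eject, find) and is illustrative only, not part of the commit.

# Minimal sketch of the inject/eject round trip verified by test_inject_eject_freeu.
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter

unet = SD1UNet(in_channels=4)

# Instantiating the adapter does not modify the UNet yet.
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
assert unet.find(FreeUResidualConcatenator) is None

# inject() patches the UNet in place; eject() restores the original structure,
# which the test checks by comparing repr(unet) before and after.
freeu.inject()
assert unet.find(FreeUResidualConcatenator) is not None

freeu.eject()
assert unet.find(FreeUResidualConcatenator) is None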

View file

@@ -0,0 +1,107 @@
+from hashlib import sha256
+from pathlib import Path
+from warnings import warn
+
+import pytest
+import torch
+from huggingface_hub import hf_hub_download  # type: ignore
+
+from refiners.fluxion.utils import load_tensors
+from refiners.foundationals.latent_diffusion import StableDiffusion_1
+from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
+
+
+@pytest.fixture
+def manager(refiners_sd15: StableDiffusion_1) -> SDLoraManager:
+    return SDLoraManager(refiners_sd15)
+
+
+@pytest.fixture
+def pokemon_lora_weights(
+    test_weights_path: Path,
+    use_local_weights: bool,
+) -> dict[str, torch.Tensor]:
+    if use_local_weights:
+        weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
+        if not weights_path.is_file():
+            warn(f"could not find weights at {weights_path}, skipping")
+            pytest.skip(allow_module_level=True)
+    else:
+        weights_path = Path(
+            hf_hub_download(
+                repo_id="pcuenq/pokemon-lora",
+                filename="pytorch_lora_weights.bin",
+                revision="bc3cb5256ebc303457acab170ca6219a66dd31f5",
+            )
+        )
+        expected_sha256 = "f712fcfb6618da14d25a4f3e0c9460a878fc2417e2df95cdd683a73f71b50384"
+        retrieved_sha256 = sha256(weights_path.read_bytes()).hexdigest().lower()
+        assert retrieved_sha256 == expected_sha256, f"expected {expected_sha256}, got {retrieved_sha256}"
+    return load_tensors(weights_path)
+
+
+def test_add_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert "pokemon-lora" in manager.names
+
+    with pytest.raises(AssertionError) as exc:
+        manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert "already exists" in str(exc.value)
+
+
+def test_add_multiple_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", pokemon_lora_weights)
+    manager.add_loras("pokemon-lora2", pokemon_lora_weights)
+    assert "pokemon-lora" in manager.names
+    assert "pokemon-lora2" in manager.names
+
+
+def test_remove_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", pokemon_lora_weights)
+    manager.add_loras("pokemon-lora2", pokemon_lora_weights)
+
+    manager.remove_loras("pokemon-lora")
+    assert "pokemon-lora" not in manager.names
+    assert "pokemon-lora2" in manager.names
+
+    manager.remove_loras("pokemon-lora2")
+    assert "pokemon-lora2" not in manager.names
+    assert len(manager.names) == 0
+
+
+def test_remove_all(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", pokemon_lora_weights)
+    manager.add_loras("pokemon-lora2", pokemon_lora_weights)
+
+    manager.remove_all()
+    assert len(manager.names) == 0
+
+
+def test_get_lora(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
+
+
+def test_get_scale(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=0.4)
+    assert manager.get_scale("pokemon-lora") == 0.4
+
+
+def test_names(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    assert manager.names == []
+
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
+    assert manager.names == ["pokemon-lora"]
+
+    manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights)
+    assert set(manager.names) == set(["pokemon-lora", "pokemon-lora2"])
+
+
+def test_scales(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
+    assert manager.scales == {}
+
+    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=0.4)
+    assert manager.scales == {"pokemon-lora": 0.4}
+
+    manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights, scale=0.5)
+    assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
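
Editor's note: the new tests above drive SDLoraManager end to end. The following sketch, built only from calls visible in this diff (hf_hub_download, load_tensors, add_loras, names, scales, remove_loras), shows the same flow outside pytest; it is not part of the commit.

# Minimal sketch: load the pokemon LoRA weights and manage them with SDLoraManager.
from pathlib import Path

from huggingface_hub import hf_hub_download  # type: ignore

from refiners.fluxion.utils import load_tensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

# Download the same pinned revision used by the pokemon_lora_weights fixture.
weights_path = Path(
    hf_hub_download(
        repo_id="pcuenq/pokemon-lora",
        filename="pytorch_lora_weights.bin",
        revision="bc3cb5256ebc303457acab170ca6219a66dd31f5",
    )
)
tensors = load_tensors(weights_path)

# Attach the LoRA to a Stable Diffusion 1.5 model under a name and scale.
manager = SDLoraManager(StableDiffusion_1())
manager.add_loras("pokemon-lora", tensors=tensors, scale=0.4)
assert manager.names == ["pokemon-lora"]
assert manager.scales == {"pokemon-lora": 0.4}

# Detach it again by name (remove_all() would clear every registered LoRA).
manager.remove_loras("pokemon-lora")
assert manager.names == []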

View file

@@ -1,91 +0,0 @@
-from pathlib import Path
-from warnings import warn
-
-import pytest
-import torch
-
-from refiners.fluxion.utils import load_tensors
-from refiners.foundationals.latent_diffusion import StableDiffusion_1
-from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
-
-
-@pytest.fixture
-def manager() -> SDLoraManager:
-    target = StableDiffusion_1()
-    return SDLoraManager(target)
-
-
-@pytest.fixture
-def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
-    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
-    if not weights_path.is_file():
-        warn(f"could not find weights at {weights_path}, skipping")
-        pytest.skip(allow_module_level=True)
-    return load_tensors(weights_path)
-
-
-def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_loras("pokemon-lora", tensors=weights)
-    assert "pokemon-lora" in manager.names
-
-    with pytest.raises(AssertionError) as exc:
-        manager.add_loras("pokemon-lora", tensors=weights)
-    assert "already exists" in str(exc.value)
-
-
-def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_loras("pokemon-lora", weights)
-    manager.add_loras("pokemon-lora2", weights)
-    assert "pokemon-lora" in manager.names
-    assert "pokemon-lora2" in manager.names
-
-
-def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_loras("pokemon-lora", weights)
-    manager.add_loras("pokemon-lora2", weights)
-
-    manager.remove_loras("pokemon-lora")
-    assert "pokemon-lora" not in manager.names
-    assert "pokemon-lora2" in manager.names
-
-    manager.remove_loras("pokemon-lora2")
-    assert "pokemon-lora2" not in manager.names
-    assert len(manager.names) == 0
-
-
-def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_loras("pokemon-lora", weights)
-    manager.add_loras("pokemon-lora2", weights)
-
-    manager.remove_all()
-    assert len(manager.names) == 0
-
-
-def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_loras("pokemon-lora", tensors=weights)
-    assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
-
-
-def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
-    assert manager.get_scale("pokemon-lora") == 0.4
-
-
-def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    assert manager.names == []
-
-    manager.add_loras("pokemon-lora", tensors=weights)
-    assert manager.names == ["pokemon-lora"]
-
-    manager.add_loras("pokemon-lora2", tensors=weights)
-    assert set(manager.names) == set(["pokemon-lora", "pokemon-lora2"])
-
-
-def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
-    assert manager.scales == {}
-
-    manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
-    assert manager.scales == {"pokemon-lora": 0.4}
-
-    manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
-    assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
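
Editor's note: the relocated tests depend on shared fixtures (refiners_unet, refiners_sd15, test_weights_path, use_local_weights) defined elsewhere in the test suite and not shown in this commit. The sketch below is a hypothetical conftest.py illustrating what they might provide; everything beyond the fixture names and the constructors already shown in the diff is an assumption.

# Hypothetical conftest.py sketch, not part of this commit.
from pathlib import Path

import pytest

from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet, StableDiffusion_1


@pytest.fixture(scope="module", params=[True, False])
def refiners_unet(request: pytest.FixtureRequest) -> SD1UNet | SDXLUNet:
    # mirrors the parametrized unet fixture removed by this commit
    xl: bool = request.param
    return SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)


@pytest.fixture
def refiners_sd15() -> StableDiffusion_1:
    # mirrors the old manager fixture, which built StableDiffusion_1() directly
    return StableDiffusion_1()


@pytest.fixture(scope="session")
def test_weights_path() -> Path:
    return Path("tests/weights")  # assumed location of locally stored test weights


@pytest.fixture(scope="session")
def use_local_weights() -> bool:
    return False  # assumed default: fetch weights from the Hugging Face Hub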