mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-09-18 18:35:28 +00:00
move some tests into the adapters test folder
This commit is contained in:
parent
a51d695523
commit
60f6f62056
|
@ -1,5 +1,3 @@
|
|||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -10,32 +8,46 @@ from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
|||
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[True, False])
|
||||
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
|
||||
xl: bool = request.param
|
||||
unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
|
||||
yield unet
|
||||
@pytest.fixture(scope="module")
|
||||
def unet(
|
||||
refiners_unet: SD1UNet | SDXLUNet,
|
||||
) -> SD1UNet | SDXLUNet:
|
||||
return refiners_unet
|
||||
|
||||
|
||||
def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
|
||||
def test_inject_eject_freeu(
|
||||
unet: SD1UNet | SDXLUNet,
|
||||
) -> None:
|
||||
initial_repr = repr(unet)
|
||||
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
|
||||
|
||||
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
freeu.eject()
|
||||
assert "could not find" in str(exc.value)
|
||||
assert unet.parent is None
|
||||
assert unet.find(FreeUResidualConcatenator) is None
|
||||
assert repr(unet) == initial_repr
|
||||
|
||||
freeu.inject()
|
||||
assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
|
||||
assert unet.parent is not None
|
||||
assert unet.find(FreeUResidualConcatenator) is not None
|
||||
assert repr(unet) != initial_repr
|
||||
|
||||
freeu.eject()
|
||||
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
|
||||
assert unet.parent is None
|
||||
assert unet.find(FreeUResidualConcatenator) is None
|
||||
assert repr(unet) == initial_repr
|
||||
|
||||
freeu.inject()
|
||||
assert unet.parent is not None
|
||||
assert unet.find(FreeUResidualConcatenator) is not None
|
||||
assert repr(unet) != initial_repr
|
||||
|
||||
freeu.eject()
|
||||
assert unet.parent is None
|
||||
assert unet.find(FreeUResidualConcatenator) is None
|
||||
assert repr(unet) == initial_repr
|
||||
|
||||
|
||||
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||
num_blocks = len(unet.layer("UpBlocks", Chain))
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
|
||||
|
||||
|
@ -43,15 +55,16 @@ def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
|||
def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
|
||||
with pytest.raises(AssertionError):
|
||||
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2, 1.2], skip_scales=[0.9, 0.9])
|
||||
|
||||
|
||||
def test_freeu_identity_scales() -> None:
|
||||
def test_freeu_identity_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||
manual_seed(0)
|
||||
text_embedding = torch.randn(1, 77, 768)
|
||||
timestep = torch.randint(0, 999, size=(1, 1))
|
||||
x = torch.randn(1, 4, 32, 32)
|
||||
text_embedding = torch.randn(1, 77, 768, dtype=unet.dtype, device=unet.device)
|
||||
timestep = torch.randint(0, 999, size=(1, 1), device=unet.device)
|
||||
x = torch.randn(1, 4, 32, 32, dtype=unet.dtype, device=unet.device)
|
||||
|
||||
unet = SD1UNet(in_channels=4)
|
||||
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
||||
|
||||
with no_grad():
|
||||
|
@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
|
|||
unet.set_timestep(timestep=timestep)
|
||||
y_2 = unet(x.clone())
|
||||
|
||||
freeu.eject()
|
||||
|
||||
# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
|
||||
assert torch.allclose(y_1, y_2, atol=1e-5)
|
107
tests/adapters/test_lora_manager.py
Normal file
107
tests/adapters/test_lora_manager.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import load_tensors
|
||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
|
||||
|
||||
|
||||
@pytest.fixture
def manager(refiners_sd15: StableDiffusion_1) -> SDLoraManager:
    """Return a fresh SDLoraManager wrapping the shared SD 1.5 model fixture."""
    lora_manager = SDLoraManager(refiners_sd15)
    return lora_manager
|
||||
|
||||
|
||||
@pytest.fixture
def pokemon_lora_weights(
    test_weights_path: Path,
    use_local_weights: bool,
) -> dict[str, torch.Tensor]:
    """Return the pokemon-lora LoRA weights as a state dict.

    Loads from the local test-weights directory when `use_local_weights` is
    set (skipping the test if the file is absent), otherwise downloads a
    pinned revision from the Hugging Face Hub. The file's SHA-256 is checked
    in both cases before the tensors are deserialized.
    """
    if use_local_weights:
        weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
        if not weights_path.is_file():
            warn(f"could not find weights at {weights_path}, skipping")
            # `allow_module_level=True` is only meaningful for module-level
            # skips; inside a fixture a plain skip with a reason is correct
            # and makes the skip reason visible in the pytest report.
            pytest.skip(f"could not find weights at {weights_path}")
    else:
        weights_path = Path(
            hf_hub_download(
                repo_id="pcuenq/pokemon-lora",
                filename="pytorch_lora_weights.bin",
                revision="bc3cb5256ebc303457acab170ca6219a66dd31f5",
            )
        )

    # Guard against corrupted or tampered downloads / local files.
    expected_sha256 = "f712fcfb6618da14d25a4f3e0c9460a878fc2417e2df95cdd683a73f71b50384"
    retrieved_sha256 = sha256(weights_path.read_bytes()).hexdigest().lower()
    assert retrieved_sha256 == expected_sha256, f"expected {expected_sha256}, got {retrieved_sha256}"

    return load_tensors(weights_path)
|
||||
|
||||
|
||||
def test_add_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """A LoRA can be registered once; re-registering the same name fails."""
    lora_name = "pokemon-lora"
    manager.add_loras(lora_name, tensors=pokemon_lora_weights)
    assert lora_name in manager.names

    # adding a second LoRA under an already-used name must be rejected
    with pytest.raises(AssertionError) as exc_info:
        manager.add_loras(lora_name, tensors=pokemon_lora_weights)
    assert "already exists" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_add_multiple_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """Two LoRAs registered under distinct names coexist in the manager."""
    for lora_name in ("pokemon-lora", "pokemon-lora2"):
        manager.add_loras(lora_name, pokemon_lora_weights)

    assert "pokemon-lora" in manager.names
    assert "pokemon-lora2" in manager.names
|
||||
|
||||
|
||||
def test_remove_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """Removing a LoRA by name only drops that entry, leaving the rest intact."""
    for lora_name in ("pokemon-lora", "pokemon-lora2"):
        manager.add_loras(lora_name, pokemon_lora_weights)

    # removing the first LoRA must not disturb the second one
    manager.remove_loras("pokemon-lora")
    assert "pokemon-lora" not in manager.names
    assert "pokemon-lora2" in manager.names

    # removing the second LoRA leaves the manager empty
    manager.remove_loras("pokemon-lora2")
    assert "pokemon-lora2" not in manager.names
    assert not manager.names
|
||||
|
||||
|
||||
def test_remove_all(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """remove_all clears every registered LoRA in one call."""
    for lora_name in ("pokemon-lora", "pokemon-lora2"):
        manager.add_loras(lora_name, pokemon_lora_weights)

    manager.remove_all()
    assert not manager.names
|
||||
|
||||
|
||||
def test_get_lora(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """Every layer returned for a registered LoRA is a Lora instance."""
    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
    for lora_layer in manager.get_loras_by_name("pokemon-lora"):
        assert isinstance(lora_layer, Lora)
|
||||
|
||||
|
||||
def test_get_scale(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """The scale passed at registration time is retrievable by name."""
    expected_scale = 0.4
    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=expected_scale)
    assert manager.get_scale("pokemon-lora") == expected_scale
|
||||
|
||||
|
||||
def test_names(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """`names` starts empty and tracks every LoRA registered so far."""
    assert manager.names == []

    # first registration: exactly one name reported
    manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
    assert manager.names == ["pokemon-lora"]

    # second registration: both names reported (order not guaranteed)
    manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights)
    assert set(manager.names) == {"pokemon-lora", "pokemon-lora2"}
|
||||
|
||||
|
||||
def test_scales(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
    """`scales` starts empty and maps each registered LoRA name to its scale."""
    expected: dict[str, float] = {}
    assert manager.scales == expected

    # after each registration, the mapping gains exactly that name/scale pair
    for lora_name, lora_scale in (("pokemon-lora", 0.4), ("pokemon-lora2", 0.5)):
        manager.add_loras(lora_name, tensors=pokemon_lora_weights, scale=lora_scale)
        expected[lora_name] = lora_scale
        assert manager.scales == expected
|
|
@ -1,91 +0,0 @@
|
|||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import load_tensors
|
||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager() -> SDLoraManager:
|
||||
target = StableDiffusion_1()
|
||||
return SDLoraManager(target)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
|
||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||
|
||||
if not weights_path.is_file():
|
||||
warn(f"could not find weights at {weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
return load_tensors(weights_path)
|
||||
|
||||
|
||||
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
manager.add_loras("pokemon-lora", tensors=weights)
|
||||
assert "pokemon-lora" in manager.names
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
manager.add_loras("pokemon-lora", tensors=weights)
|
||||
assert "already exists" in str(exc.value)
|
||||
|
||||
|
||||
def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
manager.add_loras("pokemon-lora", weights)
|
||||
manager.add_loras("pokemon-lora2", weights)
|
||||
assert "pokemon-lora" in manager.names
|
||||
assert "pokemon-lora2" in manager.names
|
||||
|
||||
|
||||
def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
manager.add_loras("pokemon-lora", weights)
|
||||
manager.add_loras("pokemon-lora2", weights)
|
||||
manager.remove_loras("pokemon-lora")
|
||||
assert "pokemon-lora" not in manager.names
|
||||
assert "pokemon-lora2" in manager.names
|
||||
|
||||
manager.remove_loras("pokemon-lora2")
|
||||
assert "pokemon-lora2" not in manager.names
|
||||
assert len(manager.names) == 0
|
||||
|
||||
|
||||
def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
manager.add_loras("pokemon-lora", weights)
|
||||
manager.add_loras("pokemon-lora2", weights)
|
||||
manager.remove_all()
|
||||
assert len(manager.names) == 0
|
||||
|
||||
|
||||
def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
manager.add_loras("pokemon-lora", tensors=weights)
|
||||
assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
|
||||
|
||||
|
||||
def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
|
||||
assert manager.get_scale("pokemon-lora") == 0.4
|
||||
|
||||
|
||||
def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
assert manager.names == []
|
||||
|
||||
manager.add_loras("pokemon-lora", tensors=weights)
|
||||
assert manager.names == ["pokemon-lora"]
|
||||
|
||||
manager.add_loras("pokemon-lora2", tensors=weights)
|
||||
assert set(manager.names) == set(["pokemon-lora", "pokemon-lora2"])
|
||||
|
||||
|
||||
def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
||||
assert manager.scales == {}
|
||||
|
||||
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
|
||||
assert manager.scales == {"pokemon-lora": 0.4}
|
||||
|
||||
manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
|
||||
assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
|
Loading…
Reference in a new issue