mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-09-19 19:05:28 +00:00
move some tests into the adapters test folder
This commit is contained in:
parent
a51d695523
commit
60f6f62056
|
@ -1,5 +1,3 @@
|
||||||
from typing import Iterator
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -10,32 +8,46 @@ from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||||
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
|
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=[True, False])
|
@pytest.fixture(scope="module")
|
||||||
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
|
def unet(
|
||||||
xl: bool = request.param
|
refiners_unet: SD1UNet | SDXLUNet,
|
||||||
unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
|
) -> SD1UNet | SDXLUNet:
|
||||||
yield unet
|
return refiners_unet
|
||||||
|
|
||||||
|
|
||||||
def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
|
def test_inject_eject_freeu(
|
||||||
|
unet: SD1UNet | SDXLUNet,
|
||||||
|
) -> None:
|
||||||
|
initial_repr = repr(unet)
|
||||||
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
|
freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
|
||||||
|
|
||||||
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
|
assert unet.parent is None
|
||||||
|
assert unet.find(FreeUResidualConcatenator) is None
|
||||||
with pytest.raises(AssertionError) as exc:
|
assert repr(unet) == initial_repr
|
||||||
freeu.eject()
|
|
||||||
assert "could not find" in str(exc.value)
|
|
||||||
|
|
||||||
freeu.inject()
|
freeu.inject()
|
||||||
assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
|
assert unet.parent is not None
|
||||||
|
assert unet.find(FreeUResidualConcatenator) is not None
|
||||||
|
assert repr(unet) != initial_repr
|
||||||
|
|
||||||
freeu.eject()
|
freeu.eject()
|
||||||
assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
|
assert unet.parent is None
|
||||||
|
assert unet.find(FreeUResidualConcatenator) is None
|
||||||
|
assert repr(unet) == initial_repr
|
||||||
|
|
||||||
|
freeu.inject()
|
||||||
|
assert unet.parent is not None
|
||||||
|
assert unet.find(FreeUResidualConcatenator) is not None
|
||||||
|
assert repr(unet) != initial_repr
|
||||||
|
|
||||||
|
freeu.eject()
|
||||||
|
assert unet.parent is None
|
||||||
|
assert unet.find(FreeUResidualConcatenator) is None
|
||||||
|
assert repr(unet) == initial_repr
|
||||||
|
|
||||||
|
|
||||||
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||||
num_blocks = len(unet.layer("UpBlocks", Chain))
|
num_blocks = len(unet.layer("UpBlocks", Chain))
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
|
SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
|
||||||
|
|
||||||
|
@ -43,15 +55,16 @@ def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||||
def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
|
def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
|
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2, 1.2], skip_scales=[0.9, 0.9])
|
||||||
|
|
||||||
|
|
||||||
def test_freeu_identity_scales() -> None:
|
def test_freeu_identity_scales(unet: SD1UNet | SDXLUNet) -> None:
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
text_embedding = torch.randn(1, 77, 768)
|
text_embedding = torch.randn(1, 77, 768, dtype=unet.dtype, device=unet.device)
|
||||||
timestep = torch.randint(0, 999, size=(1, 1))
|
timestep = torch.randint(0, 999, size=(1, 1), device=unet.device)
|
||||||
x = torch.randn(1, 4, 32, 32)
|
x = torch.randn(1, 4, 32, 32, dtype=unet.dtype, device=unet.device)
|
||||||
|
|
||||||
unet = SD1UNet(in_channels=4)
|
|
||||||
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
||||||
|
|
||||||
with no_grad():
|
with no_grad():
|
||||||
|
@ -65,5 +78,7 @@ def test_freeu_identity_scales() -> None:
|
||||||
unet.set_timestep(timestep=timestep)
|
unet.set_timestep(timestep=timestep)
|
||||||
y_2 = unet(x.clone())
|
y_2 = unet(x.clone())
|
||||||
|
|
||||||
|
freeu.eject()
|
||||||
|
|
||||||
# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
|
# The FFT -> inverse FFT sequence (skip features) introduces small numerical differences
|
||||||
assert torch.allclose(y_1, y_2, atol=1e-5)
|
assert torch.allclose(y_1, y_2, atol=1e-5)
|
107
tests/adapters/test_lora_manager.py
Normal file
107
tests/adapters/test_lora_manager.py
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
from hashlib import sha256
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download # type: ignore
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import load_tensors
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||||
|
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(refiners_sd15: StableDiffusion_1) -> SDLoraManager:
|
||||||
|
return SDLoraManager(refiners_sd15)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pokemon_lora_weights(
|
||||||
|
test_weights_path: Path,
|
||||||
|
use_local_weights: bool,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
if use_local_weights:
|
||||||
|
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
||||||
|
if not weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
else:
|
||||||
|
weights_path = Path(
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id="pcuenq/pokemon-lora",
|
||||||
|
filename="pytorch_lora_weights.bin",
|
||||||
|
revision="bc3cb5256ebc303457acab170ca6219a66dd31f5",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_sha256 = "f712fcfb6618da14d25a4f3e0c9460a878fc2417e2df95cdd683a73f71b50384"
|
||||||
|
retrieved_sha256 = sha256(weights_path.read_bytes()).hexdigest().lower()
|
||||||
|
assert retrieved_sha256 == expected_sha256, f"expected {expected_sha256}, got {retrieved_sha256}"
|
||||||
|
|
||||||
|
return load_tensors(weights_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
|
||||||
|
assert "pokemon-lora" in manager.names
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError) as exc:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
|
||||||
|
assert "already exists" in str(exc.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_multiple_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", pokemon_lora_weights)
|
||||||
|
manager.add_loras("pokemon-lora2", pokemon_lora_weights)
|
||||||
|
assert "pokemon-lora" in manager.names
|
||||||
|
assert "pokemon-lora2" in manager.names
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_loras(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", pokemon_lora_weights)
|
||||||
|
manager.add_loras("pokemon-lora2", pokemon_lora_weights)
|
||||||
|
manager.remove_loras("pokemon-lora")
|
||||||
|
assert "pokemon-lora" not in manager.names
|
||||||
|
assert "pokemon-lora2" in manager.names
|
||||||
|
|
||||||
|
manager.remove_loras("pokemon-lora2")
|
||||||
|
assert "pokemon-lora2" not in manager.names
|
||||||
|
assert len(manager.names) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_all(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", pokemon_lora_weights)
|
||||||
|
manager.add_loras("pokemon-lora2", pokemon_lora_weights)
|
||||||
|
manager.remove_all()
|
||||||
|
assert len(manager.names) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lora(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
|
||||||
|
assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_scale(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=0.4)
|
||||||
|
assert manager.get_scale("pokemon-lora") == 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def test_names(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
assert manager.names == []
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights)
|
||||||
|
assert manager.names == ["pokemon-lora"]
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights)
|
||||||
|
assert set(manager.names) == set(["pokemon-lora", "pokemon-lora2"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_scales(manager: SDLoraManager, pokemon_lora_weights: dict[str, torch.Tensor]) -> None:
|
||||||
|
assert manager.scales == {}
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora", tensors=pokemon_lora_weights, scale=0.4)
|
||||||
|
assert manager.scales == {"pokemon-lora": 0.4}
|
||||||
|
|
||||||
|
manager.add_loras("pokemon-lora2", tensors=pokemon_lora_weights, scale=0.5)
|
||||||
|
assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
|
|
@ -1,91 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_tensors
|
|
||||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
|
||||||
from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def manager() -> SDLoraManager:
|
|
||||||
target = StableDiffusion_1()
|
|
||||||
return SDLoraManager(target)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
|
|
||||||
weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
|
|
||||||
|
|
||||||
if not weights_path.is_file():
|
|
||||||
warn(f"could not find weights at {weights_path}, skipping")
|
|
||||||
pytest.skip(allow_module_level=True)
|
|
||||||
|
|
||||||
return load_tensors(weights_path)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
manager.add_loras("pokemon-lora", tensors=weights)
|
|
||||||
assert "pokemon-lora" in manager.names
|
|
||||||
|
|
||||||
with pytest.raises(AssertionError) as exc:
|
|
||||||
manager.add_loras("pokemon-lora", tensors=weights)
|
|
||||||
assert "already exists" in str(exc.value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
manager.add_loras("pokemon-lora", weights)
|
|
||||||
manager.add_loras("pokemon-lora2", weights)
|
|
||||||
assert "pokemon-lora" in manager.names
|
|
||||||
assert "pokemon-lora2" in manager.names
|
|
||||||
|
|
||||||
|
|
||||||
def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
manager.add_loras("pokemon-lora", weights)
|
|
||||||
manager.add_loras("pokemon-lora2", weights)
|
|
||||||
manager.remove_loras("pokemon-lora")
|
|
||||||
assert "pokemon-lora" not in manager.names
|
|
||||||
assert "pokemon-lora2" in manager.names
|
|
||||||
|
|
||||||
manager.remove_loras("pokemon-lora2")
|
|
||||||
assert "pokemon-lora2" not in manager.names
|
|
||||||
assert len(manager.names) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
manager.add_loras("pokemon-lora", weights)
|
|
||||||
manager.add_loras("pokemon-lora2", weights)
|
|
||||||
manager.remove_all()
|
|
||||||
assert len(manager.names) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
manager.add_loras("pokemon-lora", tensors=weights)
|
|
||||||
assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
|
|
||||||
assert manager.get_scale("pokemon-lora") == 0.4
|
|
||||||
|
|
||||||
|
|
||||||
def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
assert manager.names == []
|
|
||||||
|
|
||||||
manager.add_loras("pokemon-lora", tensors=weights)
|
|
||||||
assert manager.names == ["pokemon-lora"]
|
|
||||||
|
|
||||||
manager.add_loras("pokemon-lora2", tensors=weights)
|
|
||||||
assert set(manager.names) == set(["pokemon-lora", "pokemon-lora2"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
|
|
||||||
assert manager.scales == {}
|
|
||||||
|
|
||||||
manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
|
|
||||||
assert manager.scales == {"pokemon-lora": 0.4}
|
|
||||||
|
|
||||||
manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
|
|
||||||
assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}
|
|
Loading…
Reference in a new issue