Commit 471ef91d1c: PyTorch chose to make the return type of `nn.Module.__call__` `Any` because they expect their users' code to be "highly dynamic": https://github.com/pytorch/pytorch/pull/104321. That is not the case for us: in Refiners, having untyped code goes against one of our core principles. Note that there is currently an open PR in PyTorch to return `Module | Tensor` instead, but in practice this is not always correct either: https://github.com/pytorch/pytorch/pull/115074. I also moved the Residuals-related code from SD1 to latent_diffusion, because SDXL should not depend on SD1.
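As an illustration of the typing point above, here is a minimal sketch (not from the repository; `TypedModule` and `Double` are illustrative names, not Refiners APIs) of how a subclass can narrow the `Any` return type of `nn.Module.__call__` so downstream code stays fully typed:

import torch
from torch import Tensor, nn


class TypedModule(nn.Module):
    """Hypothetical helper: re-declares __call__ with a Tensor return type."""

    def __call__(self, *args: Tensor, **kwargs: Tensor) -> Tensor:
        # nn.Module annotates __call__ as returning Any upstream; narrowing it
        # here keeps code that calls subclasses of TypedModule fully typed.
        return super().__call__(*args, **kwargs)


class Double(TypedModule):
    def forward(self, x: Tensor) -> Tensor:
        return 2 * x


y = Double()(torch.ones(3))  # type checkers now infer `y: Tensor` instead of `Any`

This only shows the general idea of re-annotating `__call__` in a subclass; it is not meant to mirror how Refiners' own base classes are written.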
import pytest
import torch

from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter


@pytest.fixture
def lora() -> LinearLora:
    return LinearLora("test", in_features=320, out_features=128, rank=16)


@pytest.fixture
def conv_lora() -> Lora:
    return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4)


def test_properties(lora: LinearLora, conv_lora: Lora) -> None:
    assert lora.name == "test"
    assert lora.rank == lora.down.out_features == lora.up.in_features == 16
    assert lora.scale == 1.0
    assert lora.in_features == lora.down.in_features == 320
    assert lora.out_features == lora.up.out_features == 128

    assert conv_lora.name == "conv_test"
    assert conv_lora.rank == conv_lora.down.out_channels == conv_lora.up.in_channels == 4
    assert conv_lora.scale == 1.0
    assert conv_lora.in_channels == conv_lora.down.in_channels == 16
    assert conv_lora.out_channels == conv_lora.up.out_channels == 8
    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
    assert conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1)
    # padding is set so the spatial dimensions are preserved
    assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1)
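    # Worked example of the padding choice above: with the default stride of 1, the 3x3
    # down conv with padding 0 maps H x W to (H - 2) x (W - 2), and the 1x1 up conv with
    # padding 1 maps it back to H x W, so the LoRA branch preserves the spatial size.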


def test_scale_setter(lora: LinearLora) -> None:
    lora.scale = 2.0
    assert lora.scale == 2.0
    assert lora.ensure_find(fl.Multiply).scale == 2.0


def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None:
    assert isinstance(lora.down, fl.Linear) and isinstance(lora.up, fl.Linear)
    new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight)
    x = torch.randn(1, 320)
    assert torch.allclose(lora(x), new_lora(x))

    assert isinstance(conv_lora.down, fl.Conv2d) and isinstance(conv_lora.up, fl.Conv2d)
    new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight)
    x = torch.randn(1, 16, 64, 64)
    assert torch.allclose(conv_lora(x), new_conv_lora(x))
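    # Note (assumption about Lora's internal structure, not asserted in this file): each
    # lora is expected to compute scale * up(down(x)), so rebuilding a lora via
    # from_weights with the same down/up tensors yields an identical forward pass.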


def test_from_dict() -> None:
    state_dict = {
        "down.weight": torch.randn(320, 128),
        "up.weight": torch.randn(128, 320),
        "this.is_not_used.alpha": torch.randn(1, 320),
        "probably.a.conv.down.weight": torch.randn(4, 16, 3, 3),
        "probably.a.conv.up.weight": torch.randn(8, 4, 1, 1),
    }
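    # Expected parsing behaviour, inferred from the asserts below: from_dict pairs the
    # "...down.weight" / "...up.weight" tensors, building a LinearLora from the 2D pair
    # and a Conv2dLora from the 4D pair, while unrelated keys (the alpha entry) are skipped.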
    loras = Lora.from_dict("test", state_dict=state_dict)
    assert len(loras) == 2
    linear_lora, conv_lora = loras.values()
    assert isinstance(linear_lora, LinearLora)
    assert isinstance(conv_lora, Conv2dLora)
    assert linear_lora.name == "test"
    assert conv_lora.name == "test"


@pytest.fixture
def lora_adapter() -> LoraAdapter:
    target = fl.Linear(320, 128)
    lora1 = LinearLora("test1", in_features=320, out_features=128, rank=16, scale=2.0)
    lora2 = LinearLora("test2", in_features=320, out_features=128, rank=16, scale=-1.0)
    return LoraAdapter(target, lora1, lora2)
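    # Two loras with distinct names and scales share a single Linear target; the tests
    # below exercise the adapter's bookkeeping API (names, loras, scales, add/remove).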


def test_names(lora_adapter: LoraAdapter) -> None:
    assert set(lora_adapter.names) == {"test1", "test2"}


def test_loras(lora_adapter: LoraAdapter) -> None:
    assert set(lora_adapter.loras.keys()) == {"test1", "test2"}


def test_scales(lora_adapter: LoraAdapter) -> None:
    assert set(lora_adapter.scales.keys()) == {"test1", "test2"}
    assert lora_adapter.scales["test1"] == 2.0
    assert lora_adapter.scales["test2"] == -1.0


def test_scale_setter_lora_adapter(lora_adapter: LoraAdapter) -> None:
    lora_adapter.scale = {"test1": 0.0, "test2": 3.0}
    assert lora_adapter.scales == {"test1": 0.0, "test2": 3.0}


def test_add_lora(lora_adapter: LoraAdapter) -> None:
    lora3 = LinearLora("test3", in_features=320, out_features=128, rank=16)
    lora_adapter.add_lora(lora3)
    assert "test3" in lora_adapter.names


def test_remove_lora(lora_adapter: LoraAdapter) -> None:
    lora_adapter.remove_lora("test1")
    assert "test1" not in lora_adapter.names


def test_add_existing_lora(lora_adapter: LoraAdapter) -> None:
    lora3 = LinearLora("test1", in_features=320, out_features=128, rank=16)
    with pytest.raises(AssertionError):
        lora_adapter.add_lora(lora3)


def test_remove_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
    assert lora_adapter.remove_lora("test3") is None


def test_set_scale_for_nonexistent_lora(lora_adapter: LoraAdapter) -> None:
    with pytest.raises(KeyError):
        lora_adapter.scale = {"test3": 2.0}