refiners/tests/adapters/test_lora.py

30 lines
822 B
Python
Raw Normal View History

2023-08-04 13:28:41 +00:00
from refiners.adapters.lora import Lora, LoraAdapter
from torch import randn, allclose
import refiners.fluxion.layers as fl
def test_lora() -> None:
    """Check that a LoraAdapter can be injected, ejected and re-injected around a Linear.

    The adapter wraps the first Linear of an inner chain. While injected, the
    model output must match the un-adapted output (presumably because the Lora
    branch starts as a no-op — TODO confirm against ``refiners.adapters.lora``).
    Ejecting must put the very same Linear back in place without resizing the
    inner chain, and the adapter must be re-injectable afterwards.
    """
    chain = fl.Chain(
        fl.Chain(
            fl.Linear(in_features=1, out_features=1),
            fl.Linear(in_features=1, out_features=1),
        ),
        fl.Linear(in_features=1, out_features=2),
    )
    # Keep direct references so later checks can verify that ejection restores
    # the original module objects, not merely modules of the same type.
    inner = chain.Chain
    target = inner.Linear_1
    sibling = inner[1]

    x = randn(1, 1)
    y = chain(x)

    lora_adapter = LoraAdapter(target)
    lora_adapter.inject(inner)
    assert isinstance(inner[0], LoraAdapter)
    assert isinstance(lora_adapter[1], Lora)
    assert allclose(input=chain(x), other=y)
    assert lora_adapter.parent is inner

    lora_adapter.eject()
    # The outer chain always has two children, so its length alone proves
    # nothing; the inner chain is the one the adapter was spliced into.
    assert inner[0] is target
    assert isinstance(inner[0], fl.Linear)
    assert inner[1] is sibling
    assert len(inner) == 2
    assert len(chain) == 2
    assert allclose(input=chain(x), other=y)

    # The same adapter instance must be injectable again after an eject.
    lora_adapter.inject(inner)
    assert isinstance(inner[0], LoraAdapter)
    assert allclose(input=chain(x), other=y)