mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-25 07:38:45 +00:00)
from refiners.adapters.lora import Lora, LoraAdapter
from torch import randn, allclose
import refiners.fluxion.layers as fl

def test_lora() -> None:
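    # Build a toy model: a nested Chain of two 1->1 Linears, followed by a 1->2 Linear.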
    chain = fl.Chain(
        fl.Chain(
            fl.Linear(in_features=1, out_features=1),
            fl.Linear(in_features=1, out_features=1),
        ),
        fl.Linear(in_features=1, out_features=2),
    )
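    # Record the output before adaptation for later comparison.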
    x = randn(1, 1)
    y = chain(x)

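    # Target the first Linear of the nested Chain and inject the adapter in its place.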
    lora_adapter = LoraAdapter(chain.Chain.Linear_1)
    lora_adapter.inject(chain.Chain)

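    # The Lora sits at index 1 of the adapter, injection leaves the output
    # unchanged, and the adapter knows its parent.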
    assert isinstance(lora_adapter[1], Lora)
    assert allclose(input=chain(x), other=y)
    assert lora_adapter.parent == chain.Chain

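    # Ejecting restores the original Linear; the outer chain still has its two children.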
    lora_adapter.eject()
    assert isinstance(chain.Chain[0], fl.Linear)
    assert len(chain) == 2

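    # The same adapter can be injected again after ejection.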
    lora_adapter.inject(chain.Chain)
    assert isinstance(chain.Chain[0], LoraAdapter)