mirror of https://github.com/finegrain-ai/refiners.git
fix LoRAs on Self target
This commit is contained in:
parent 3565a4127f
commit 2ad26a06b0
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from loguru import logger
 from refiners.adapters.lora import LoraAdapter, Lora
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -28,12 +28,8 @@ class LoraConfig(BaseModel):
     lda_targets: list[LoraTarget]
 
     def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
-        for layer in module.layers(layer_type=target.get_class()):
-            for linear, parent in layer.walk(fl.Linear):
-                adapter = LoraAdapter(
-                    target=linear,
-                    rank=self.rank,
-                )
-                adapter.inject(parent)
-                for linear in adapter.Lora.layers(fl.Linear):
-                    linear.requires_grad_(requires_grad=True)
+        for linear, parent in lora_targets(module, target):
+            adapter = LoraAdapter(target=linear, rank=self.rank)
+            adapter.inject(parent)
+            for linear in adapter.Lora.layers(fl.Linear):
+                linear.requires_grad_(requires_grad=True)
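The LoraConfig.apply_loras_to_target method above now delegates target resolution to the new lora_targets helper from refiners.foundationals.latent_diffusion.lora (added in the hunks below) and keeps the rest of the loop: inject a LoraAdapter around each matched Linear, then mark only the adapter's own Linear layers as trainable. A minimal sketch of that same pattern outside the config class, using only API that appears in this diff; the toy chain and the explicit freeze of the base weights are illustrative assumptions, not part of the commit:

import refiners.fluxion.layers as fl
from refiners.adapters.lora import LoraAdapter
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

# Toy stand-in for a real model chain.
chain = fl.Chain(fl.Linear(in_features=4, out_features=4))
chain.requires_grad_(requires_grad=False)  # freeze base weights (assumed to be handled by the trainer)

for linear, parent in lora_targets(chain, LoraTarget.Self):
    adapter = LoraAdapter(target=linear, rank=2)
    adapter.inject(parent)
    for lora_linear in adapter.Lora.layers(fl.Linear):
        lora_linear.requires_grad_(requires_grad=True)  # only the LoRA weights receive gradients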
@@ -1,5 +1,7 @@
 from enum import Enum
 from pathlib import Path
+from typing import Iterator
+
 
 from torch import Tensor, device as Device
 from torch.nn import Parameter as TorchParameter
@@ -42,14 +44,16 @@ def get_lora_rank(weights: list[Tensor]) -> int:
     return ranks.pop()
 
 
+def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+    it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
+    for layer in it:
+        for t in layer.walk(fl.Linear):
+            yield t
+
+
 def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
-    for layer in module.layers(layer_type=target.get_class()):
-        for linear, parent in layer.walk(fl.Linear):
-            adapter = LoraAdapter(
-                target=linear,
-                rank=rank,
-                scale=scale,
-            )
-            adapter.inject(parent)
+    for linear, parent in lora_targets(module, target):
+        adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
+        adapter.inject(parent)
 
 
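The heart of the fix is the Self special case in lora_targets: with LoraTarget.Self the chain itself is walked, so Linear layers sitting directly at the top level are matched as well, rather than only those found under sub-layers of target.get_class(). A small illustrative sketch of the generator on the same toy chain the new test builds below; the expected count mirrors the test's assertion:

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

chain = fl.Chain(
    fl.Chain(
        fl.Linear(in_features=1, out_features=1),
        fl.Linear(in_features=1, out_features=1),
    ),
    fl.Linear(in_features=1, out_features=2),  # top-level Linear, not wrapped in a sub-Chain
)

# With the Self target the whole chain is walked: all three Linears are yielded
# as (linear, parent) pairs, ready for adapter injection.
assert len(list(lora_targets(chain, LoraTarget.Self))) == 3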
tests/foundationals/latent_diffusion/test_lora.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+from refiners.adapters.lora import Lora
+from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
+import refiners.fluxion.layers as fl
+
+
+def test_lora_target_self() -> None:
+    chain = fl.Chain(
+        fl.Chain(
+            fl.Linear(in_features=1, out_features=1),
+            fl.Linear(in_features=1, out_features=1),
+        ),
+        fl.Linear(in_features=1, out_features=2),
+    )
+    apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
+
+    assert len(list(chain.layers(Lora))) == 3
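Beyond counting the injected Lora layers, a follow-up sanity check could verify that injection leaves the chain's output shape untouched. A minimal sketch, assuming an fl.Chain of Linears can be called on a tensor like any torch module (not part of the commit):

import torch

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target

chain = fl.Chain(
    fl.Chain(
        fl.Linear(in_features=1, out_features=1),
        fl.Linear(in_features=1, out_features=1),
    ),
    fl.Linear(in_features=1, out_features=2),
)
apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)

# The LoRA branch is a parallel low-rank path, so shapes are preserved: (2, 1) -> (2, 2).
assert chain(torch.randn(2, 1)).shape == (2, 2)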