mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix LoRAs on Self target
This commit is contained in:
parent
3565a4127f
commit
2ad26a06b0
|
@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||
from loguru import logger
|
||||
from refiners.adapters.lora import LoraAdapter, Lora
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
||||
import refiners.fluxion.layers as fl
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -28,12 +28,8 @@ class LoraConfig(BaseModel):
|
|||
lda_targets: list[LoraTarget]
|
||||
|
||||
def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
|
||||
for layer in module.layers(layer_type=target.get_class()):
|
||||
for linear, parent in layer.walk(fl.Linear):
|
||||
adapter = LoraAdapter(
|
||||
target=linear,
|
||||
rank=self.rank,
|
||||
)
|
||||
for linear, parent in lora_targets(module, target):
|
||||
adapter = LoraAdapter(target=linear, rank=self.rank)
|
||||
adapter.inject(parent)
|
||||
for linear in adapter.Lora.layers(fl.Linear):
|
||||
linear.requires_grad_(requires_grad=True)
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
|
||||
from torch import Tensor, device as Device
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
|
@ -42,14 +44,16 @@ def get_lora_rank(weights: list[Tensor]) -> int:
|
|||
return ranks.pop()
|
||||
|
||||
|
||||
def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
|
||||
for layer in it:
|
||||
for t in layer.walk(fl.Linear):
|
||||
yield t
|
||||
|
||||
|
||||
def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
|
||||
for layer in module.layers(layer_type=target.get_class()):
|
||||
for linear, parent in layer.walk(fl.Linear):
|
||||
adapter = LoraAdapter(
|
||||
target=linear,
|
||||
rank=rank,
|
||||
scale=scale,
|
||||
)
|
||||
for linear, parent in lora_targets(module, target):
|
||||
adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
|
||||
adapter.inject(parent)
|
||||
|
||||
|
||||
|
|
16
tests/foundationals/latent_diffusion/test_lora.py
Normal file
16
tests/foundationals/latent_diffusion/test_lora.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from refiners.adapters.lora import Lora
|
||||
from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
def test_lora_target_self() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=2),
|
||||
)
|
||||
apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
|
||||
|
||||
assert len(list(chain.layers(Lora))) == 3
|
Loading…
Reference in a new issue