fix LoRAs on Self target

Pierre Chapuis 2023-08-22 18:03:26 +02:00
parent 3565a4127f
commit 2ad26a06b0
3 changed files with 34 additions and 18 deletions


@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from loguru import logger
 from refiners.adapters.lora import LoraAdapter, Lora
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -28,15 +28,11 @@ class LoraConfig(BaseModel):
     lda_targets: list[LoraTarget]

     def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
-        for layer in module.layers(layer_type=target.get_class()):
-            for linear, parent in layer.walk(fl.Linear):
-                adapter = LoraAdapter(
-                    target=linear,
-                    rank=self.rank,
-                )
-                adapter.inject(parent)
-                for linear in adapter.Lora.layers(fl.Linear):
-                    linear.requires_grad_(requires_grad=True)
+        for linear, parent in lora_targets(module, target):
+            adapter = LoraAdapter(target=linear, rank=self.rank)
+            adapter.inject(parent)
+            for linear in adapter.Lora.layers(fl.Linear):
+                linear.requires_grad_(requires_grad=True)


 class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
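
Note on the training-side hunk above: the config method now delegates traversal to the shared lora_targets helper and only flips the injected LoRA linears to requires_grad=True. As a rough sketch of what that leaves trainable (the helper name below is an assumption, not part of this commit):

import refiners.fluxion.layers as fl
from refiners.adapters.lora import Lora
from torch.nn import Parameter


def trainable_lora_parameters(module: fl.Chain) -> list[Parameter]:
    # Hypothetical helper: gather only the parameters of Linear layers that live
    # inside a Lora block, i.e. the ones the method above sets to requires_grad=True.
    params: list[Parameter] = []
    for lora in module.layers(Lora):
        for linear in lora.layers(fl.Linear):
            params.extend(p for p in linear.parameters() if p.requires_grad)
    return params

Such a list is what one would typically hand to the optimizer so the base model stays frozen.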


@@ -1,5 +1,7 @@
 from enum import Enum
 from pathlib import Path
+from typing import Iterator
 from torch import Tensor, device as Device
 from torch.nn import Parameter as TorchParameter
@@ -42,15 +44,17 @@ def get_lora_rank(weights: list[Tensor]) -> int:
     return ranks.pop()


+def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+    it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
+    for layer in it:
+        for t in layer.walk(fl.Linear):
+            yield t
+
+
 def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
-    for layer in module.layers(layer_type=target.get_class()):
-        for linear, parent in layer.walk(fl.Linear):
-            adapter = LoraAdapter(
-                target=linear,
-                rank=rank,
-                scale=scale,
-            )
-            adapter.inject(parent)
+    for linear, parent in lora_targets(module, target):
+        adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
+        adapter.inject(parent)


 class LoraWeights:
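
This is where the fix lives: the new lora_targets generator special-cases LoraTarget.Self by walking the module itself instead of whatever module.layers(...) returns for that target's class, and apply_loras_to_target now reuses it. A minimal sketch of calling these module-level helpers directly (the toy chain is illustrative; only the imported names come from this file):

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.lora import (
    LoraTarget,
    apply_loras_to_target,
    lora_targets,
)

chain = fl.Chain(
    fl.Linear(in_features=4, out_features=4),
    fl.Chain(fl.Linear(in_features=4, out_features=4)),
)

# With LoraTarget.Self, every Linear in the chain is yielded with its parent.
for linear, parent in lora_targets(chain, LoraTarget.Self):
    print(type(linear).__name__, "inside", type(parent).__name__)

# Same traversal, but each yielded Linear gets wrapped by a LoraAdapter.
apply_loras_to_target(chain, LoraTarget.Self, rank=4, scale=1.0)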


@@ -0,0 +1,16 @@
+from refiners.adapters.lora import Lora
+from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
+import refiners.fluxion.layers as fl
+
+
+def test_lora_target_self() -> None:
+    chain = fl.Chain(
+        fl.Chain(
+            fl.Linear(in_features=1, out_features=1),
+            fl.Linear(in_features=1, out_features=1),
+        ),
+        fl.Linear(in_features=1, out_features=2),
+    )
+
+    apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
+    assert len(list(chain.layers(Lora))) == 3
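
The assertion expects three Lora blocks because, with LoraTarget.Self, lora_targets starts from the chain itself and walks every fl.Linear it contains: the two inside the nested sub-Chain plus the top-level one each receive an adapter.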