fix LoRAs on Self target

Pierre Chapuis 2023-08-22 18:03:26 +02:00
parent 3565a4127f
commit 2ad26a06b0
3 changed files with 34 additions and 18 deletions

View file

@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from loguru import logger
 from refiners.adapters.lora import LoraAdapter, Lora
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -28,12 +28,8 @@ class LoraConfig(BaseModel):
     lda_targets: list[LoraTarget]
 
     def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
-        for layer in module.layers(layer_type=target.get_class()):
-            for linear, parent in layer.walk(fl.Linear):
-                adapter = LoraAdapter(
-                    target=linear,
-                    rank=self.rank,
-                )
-                adapter.inject(parent)
-                for linear in adapter.Lora.layers(fl.Linear):
-                    linear.requires_grad_(requires_grad=True)
+        for linear, parent in lora_targets(module, target):
+            adapter = LoraAdapter(target=linear, rank=self.rank)
+            adapter.inject(parent)
+            for linear in adapter.Lora.layers(fl.Linear):
+                linear.requires_grad_(requires_grad=True)
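
For reference, a minimal sketch (not part of this commit) of the injection pattern the config method above now follows, using only the calls visible in the diff (lora_targets, LoraAdapter, inject, requires_grad_); the toy chain and rank here are arbitrary:

import refiners.fluxion.layers as fl
from refiners.adapters.lora import LoraAdapter
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

# toy single-layer chain; in the training script `module` is the adapted model
# and `rank` comes from LoraConfig
module = fl.Chain(fl.Linear(in_features=4, out_features=4))
for linear, parent in lora_targets(module, LoraTarget.Self):
    adapter = LoraAdapter(target=linear, rank=2)
    adapter.inject(parent)  # wrap the Linear in place
    for lora_linear in adapter.Lora.layers(fl.Linear):
        lora_linear.requires_grad_(requires_grad=True)  # enable gradients on the LoRA weights, as the script does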

View file

@@ -1,5 +1,7 @@
 from enum import Enum
 from pathlib import Path
+from typing import Iterator
+
 from torch import Tensor, device as Device
 from torch.nn import Parameter as TorchParameter
@@ -42,14 +44,16 @@ def get_lora_rank(weights: list[Tensor]) -> int:
     return ranks.pop()
 
 
+def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+    it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
+    for layer in it:
+        for t in layer.walk(fl.Linear):
+            yield t
+
+
 def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
-    for layer in module.layers(layer_type=target.get_class()):
-        for linear, parent in layer.walk(fl.Linear):
-            adapter = LoraAdapter(
-                target=linear,
-                rank=rank,
-                scale=scale,
-            )
-            adapter.inject(parent)
+    for linear, parent in lora_targets(module, target):
+        adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
+        adapter.inject(parent)
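
To make the Self special case concrete, here is a small sketch (not part of this commit) of what lora_targets yields for LoraTarget.Self, on a chain shaped like the one in the new test below; it assumes, as that test does, that each fl.Linear is yielded exactly once:

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

chain = fl.Chain(
    fl.Chain(
        fl.Linear(in_features=1, out_features=1),
        fl.Linear(in_features=1, out_features=1),
    ),
    fl.Linear(in_features=1, out_features=2),
)
# Self now walks the chain itself ([module] in the helper), so all three Linear
# layers are paired with their parent Chain; other targets still go through
# module.layers(layer_type=target.get_class()) as before.
pairs = list(lora_targets(chain, LoraTarget.Self))
assert len(pairs) == 3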

View file

@@ -0,0 +1,16 @@
+from refiners.adapters.lora import Lora
+from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
+
+import refiners.fluxion.layers as fl
+
+
+def test_lora_target_self() -> None:
+    chain = fl.Chain(
+        fl.Chain(
+            fl.Linear(in_features=1, out_features=1),
+            fl.Linear(in_features=1, out_features=1),
+        ),
+        fl.Linear(in_features=1, out_features=2),
+    )
+    apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
+    assert len(list(chain.layers(Lora))) == 3
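
About the expected count: the test chain contains three fl.Linear layers (two inside the nested fl.Chain plus one at the top level), and with LoraTarget.Self each of them gets a LoraAdapter injected, so three Lora layers should be found afterwards.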