diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py
index abffc46..fb6e030 100644
--- a/scripts/training/finetune-ldm-lora.py
+++ b/scripts/training/finetune-ldm-lora.py
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from loguru import logger
 from refiners.adapters.lora import LoraAdapter, Lora
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -28,15 +28,11 @@ class LoraConfig(BaseModel):
     lda_targets: list[LoraTarget]
 
     def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
-        for layer in module.layers(layer_type=target.get_class()):
-            for linear, parent in layer.walk(fl.Linear):
-                adapter = LoraAdapter(
-                    target=linear,
-                    rank=self.rank,
-                )
-                adapter.inject(parent)
-                for linear in adapter.Lora.layers(fl.Linear):
-                    linear.requires_grad_(requires_grad=True)
+        for linear, parent in lora_targets(module, target):
+            adapter = LoraAdapter(target=linear, rank=self.rank)
+            adapter.inject(parent)
+            for linear in adapter.Lora.layers(fl.Linear):
+                linear.requires_grad_(requires_grad=True)
 
 
 class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index ac29631..2da54e1 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -1,5 +1,7 @@
 from enum import Enum
 from pathlib import Path
+from typing import Iterator
+
 from torch import Tensor, device as Device
 from torch.nn import Parameter as TorchParameter
 
@@ -42,15 +44,17 @@ def get_lora_rank(weights: list[Tensor]) -> int:
     return ranks.pop()
 
 
+def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+    it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
+    for layer in it:
+        for t in layer.walk(fl.Linear):
+            yield t
+
+
 def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
-    for layer in module.layers(layer_type=target.get_class()):
-        for linear, parent in layer.walk(fl.Linear):
-            adapter = LoraAdapter(
-                target=linear,
-                rank=rank,
-                scale=scale,
-            )
-            adapter.inject(parent)
+    for linear, parent in lora_targets(module, target):
+        adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
+        adapter.inject(parent)
 
 
 class LoraWeights:
diff --git a/tests/foundationals/latent_diffusion/test_lora.py b/tests/foundationals/latent_diffusion/test_lora.py
new file mode 100644
index 0000000..4df6140
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_lora.py
@@ -0,0 +1,16 @@
+from refiners.adapters.lora import Lora
+from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
+import refiners.fluxion.layers as fl
+
+
+def test_lora_target_self() -> None:
+    chain = fl.Chain(
+        fl.Chain(
+            fl.Linear(in_features=1, out_features=1),
+            fl.Linear(in_features=1, out_features=1),
+        ),
+        fl.Linear(in_features=1, out_features=2),
+    )
+    apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
+
+    assert len(list(chain.layers(Lora))) == 3