From eba0c330011ba7c1d989dba5c292081353ac41cc Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 1 Sep 2023 10:29:19 +0200 Subject: [PATCH] allow lora_targets to take a list of targets as input --- scripts/training/finetune-ldm-lora.py | 2 +- src/refiners/foundationals/latent_diffusion/lora.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py index 40ad8a5..11d02e4 100644 --- a/scripts/training/finetune-ldm-lora.py +++ b/scripts/training/finetune-ldm-lora.py @@ -82,7 +82,7 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]): model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets") adapter = LoraAdapter[type(model)]( model, - sub_targets=[x for target in model_targets for x in lora_targets(model, target)], + sub_targets=lora_targets(model, model_targets), rank=lora_config.rank, ) for sub_adapter, _ in adapter.sub_adapters: diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 910da40..d5ddd25 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -47,7 +47,15 @@ class LoraTarget(str, Enum): return TransformerLayer -def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]: +def lora_targets( + module: fl.Chain, + target: LoraTarget | list[LoraTarget], +) -> Iterator[tuple[fl.Linear, fl.Chain]]: + if isinstance(target, list): + for t in target: + yield from lora_targets(module, t) + return + lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class() if isinstance(module, SD1UNet): @@ -100,7 +108,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): self.sub_adapters.append( LoraAdapter[type(model)]( model, - sub_targets=[x for target in model_targets for x in lora_targets(model, target)], + sub_targets=lora_targets(model, model_targets), scale=scale, weights=lora_weights, )