allow lora_targets to take a list of targets as input

This commit is contained in:
Pierre Chapuis 2023-09-01 10:29:19 +02:00
parent 92cdf19eae
commit eba0c33001
2 changed files with 11 additions and 3 deletions

View file

@ -82,7 +82,7 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
adapter = LoraAdapter[type(model)](
model,
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
sub_targets=lora_targets(model, model_targets),
rank=lora_config.rank,
)
for sub_adapter, _ in adapter.sub_adapters:

View file

@ -47,7 +47,15 @@ class LoraTarget(str, Enum):
return TransformerLayer
def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
def lora_targets(
module: fl.Chain,
target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
if isinstance(target, list):
for t in target:
yield from lora_targets(module, t)
return
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
if isinstance(module, SD1UNet):
@ -100,7 +108,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
sub_targets=lora_targets(model, model_targets),
scale=scale,
weights=lora_weights,
)