allow lora_targets to take a list of targets as input

This commit is contained in:
Pierre Chapuis 2023-09-01 10:29:19 +02:00
parent 92cdf19eae
commit eba0c33001
2 changed files with 11 additions and 3 deletions

View file

@@ -82,7 +82,7 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
         model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
         adapter = LoraAdapter[type(model)](
             model,
-            sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
+            sub_targets=lora_targets(model, model_targets),
             rank=lora_config.rank,
         )
         for sub_adapter, _ in adapter.sub_adapters:

View file

@@ -47,7 +47,15 @@ class LoraTarget(str, Enum):
         return TransformerLayer

-def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+def lora_targets(
+    module: fl.Chain,
+    target: LoraTarget | list[LoraTarget],
+) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+    if isinstance(target, list):
+        for t in target:
+            yield from lora_targets(module, t)
+        return
     lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
     if isinstance(module, SD1UNet):
@@ -100,7 +108,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
         self.sub_adapters.append(
             LoraAdapter[type(model)](
                 model,
-                sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
+                sub_targets=lora_targets(model, model_targets),
                 scale=scale,
                 weights=lora_weights,
             )