mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
allow lora_targets to take a list of targets as input
This commit is contained in:
parent
92cdf19eae
commit
eba0c33001
|
@ -82,7 +82,7 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
|||
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
||||
adapter = LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
||||
sub_targets=lora_targets(model, model_targets),
|
||||
rank=lora_config.rank,
|
||||
)
|
||||
for sub_adapter, _ in adapter.sub_adapters:
|
||||
|
|
|
@ -47,7 +47,15 @@ class LoraTarget(str, Enum):
|
|||
return TransformerLayer
|
||||
|
||||
|
||||
def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
def lora_targets(
|
||||
module: fl.Chain,
|
||||
target: LoraTarget | list[LoraTarget],
|
||||
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
if isinstance(target, list):
|
||||
for t in target:
|
||||
yield from lora_targets(module, t)
|
||||
return
|
||||
|
||||
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
|
||||
|
||||
if isinstance(module, SD1UNet):
|
||||
|
@ -100,7 +108,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
|||
self.sub_adapters.append(
|
||||
LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
||||
sub_targets=lora_targets(model, model_targets),
|
||||
scale=scale,
|
||||
weights=lora_weights,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue