mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
allow lora_targets to take a list of targets as input
This commit is contained in:
parent
92cdf19eae
commit
eba0c33001
|
@ -82,7 +82,7 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||||
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
||||||
adapter = LoraAdapter[type(model)](
|
adapter = LoraAdapter[type(model)](
|
||||||
model,
|
model,
|
||||||
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
sub_targets=lora_targets(model, model_targets),
|
||||||
rank=lora_config.rank,
|
rank=lora_config.rank,
|
||||||
)
|
)
|
||||||
for sub_adapter, _ in adapter.sub_adapters:
|
for sub_adapter, _ in adapter.sub_adapters:
|
||||||
|
|
|
@ -47,7 +47,15 @@ class LoraTarget(str, Enum):
|
||||||
return TransformerLayer
|
return TransformerLayer
|
||||||
|
|
||||||
|
|
||||||
def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
def lora_targets(
|
||||||
|
module: fl.Chain,
|
||||||
|
target: LoraTarget | list[LoraTarget],
|
||||||
|
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||||
|
if isinstance(target, list):
|
||||||
|
for t in target:
|
||||||
|
yield from lora_targets(module, t)
|
||||||
|
return
|
||||||
|
|
||||||
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
|
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
|
||||||
|
|
||||||
if isinstance(module, SD1UNet):
|
if isinstance(module, SD1UNet):
|
||||||
|
@ -100,7 +108,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||||
self.sub_adapters.append(
|
self.sub_adapters.append(
|
||||||
LoraAdapter[type(model)](
|
LoraAdapter[type(model)](
|
||||||
model,
|
model,
|
||||||
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
sub_targets=lora_targets(model, model_targets),
|
||||||
scale=scale,
|
scale=scale,
|
||||||
weights=lora_weights,
|
weights=lora_weights,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue