ability to get LoRA weights in SDLoraManager

This commit is contained in:
Pierre Chapuis 2024-03-06 09:53:32 +01:00
parent fb90b00e75
commit cd5fa97c20

View file

@ -92,6 +92,25 @@ class SDLoraManager:
# set the scale of the LoRA
self.set_scale(name, scale)
def _get_lora_weights(self, base: fl.Chain, name: str, accum: dict[str, Tensor]) -> None:
prev_parent: fl.Chain | None = None
n = 0
for lora_adapter, parent in base.walk(LoraAdapter):
lora = next((l for l in lora_adapter.lora_layers if l.name == name), None)
if lora is None:
continue
n = (parent == prev_parent) and n + 1 or 1
pfx = f"{parent.get_path()}.{n}.{lora_adapter.target.__class__.__name__}"
accum[f"{pfx}.down.weight"] = lora.down.weight
accum[f"{pfx}.up.weight"] = lora.up.weight
prev_parent = parent
def get_lora_weights(self, name: str) -> dict[str, Tensor]:
r: dict[str, Tensor] = {}
self._get_lora_weights(self.unet, name, r)
self._get_lora_weights(self.clip_text_encoder, name, r)
return r
def add_loras_to_text_encoder(
self,
loras: dict[str, Lora[Any]],