mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
ability to get LoRA weights in SDLoraManager
This commit is contained in:
parent
fb90b00e75
commit
cd5fa97c20
|
@ -92,6 +92,25 @@ class SDLoraManager:
|
||||||
# set the scale of the LoRA
|
# set the scale of the LoRA
|
||||||
self.set_scale(name, scale)
|
self.set_scale(name, scale)
|
||||||
|
|
||||||
|
def _get_lora_weights(self, base: fl.Chain, name: str, accum: dict[str, Tensor]) -> None:
    """Collect the up/down weights of every LoRA layer named `name` found under `base`.

    Walks `base` for `LoraAdapter`s and, for each adapter holding a LoRA layer whose
    `.name` matches `name`, stores its `down` and `up` weights into `accum` under keys of
    the form ``"<parent path>.<index>.<target class name>.{down,up}.weight"``. The index
    is 1-based and counts consecutive matches sharing the same parent chain, restarting
    at 1 whenever the parent changes.

    Args:
        base: the chain to search for LoRA adapters.
        name: the LoRA name to match against each adapter's `lora_layers`.
        accum: mapping mutated in place; keys are paths, values are weight tensors.
    """
    previous_parent: fl.Chain | None = None
    occurrence = 0
    for lora_adapter, parent in base.walk(LoraAdapter):
        lora = next((layer for layer in lora_adapter.lora_layers if layer.name == name), None)
        if lora is None:
            continue
        # 1-based index of consecutive matches under the same parent; previously
        # written as the fragile `cond and a or b` idiom — use a real ternary.
        occurrence = occurrence + 1 if parent == previous_parent else 1
        prefix = f"{parent.get_path()}.{occurrence}.{lora_adapter.target.__class__.__name__}"
        accum[f"{prefix}.down.weight"] = lora.down.weight
        accum[f"{prefix}.up.weight"] = lora.up.weight
        previous_parent = parent
|
||||||
|
|
||||||
|
def get_lora_weights(self, name: str) -> dict[str, Tensor]:
    """Return the weights of the LoRA named `name`, gathered from both models.

    Searches the UNet and the CLIP text encoder for adapters carrying `name` and
    returns a mapping from state-dict-style key paths to their weight tensors.
    """
    weights: dict[str, Tensor] = {}
    for model in (self.unet, self.clip_text_encoder):
        self._get_lora_weights(model, name, weights)
    return weights
|
||||||
|
|
||||||
def add_loras_to_text_encoder(
|
def add_loras_to_text_encoder(
|
||||||
self,
|
self,
|
||||||
loras: dict[str, Lora[Any]],
|
loras: dict[str, Lora[Any]],
|
||||||
|
|
Loading…
Reference in a new issue