mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
ability to get LoRA weights in SDLoraManager
This commit is contained in:
parent
fb90b00e75
commit
cd5fa97c20
|
@ -92,6 +92,25 @@ class SDLoraManager:
|
|||
# set the scale of the LoRA
|
||||
self.set_scale(name, scale)
|
||||
|
||||
def _get_lora_weights(self, base: fl.Chain, name: str, accum: dict[str, Tensor]) -> None:
|
||||
prev_parent: fl.Chain | None = None
|
||||
n = 0
|
||||
for lora_adapter, parent in base.walk(LoraAdapter):
|
||||
lora = next((l for l in lora_adapter.lora_layers if l.name == name), None)
|
||||
if lora is None:
|
||||
continue
|
||||
n = (parent == prev_parent) and n + 1 or 1
|
||||
pfx = f"{parent.get_path()}.{n}.{lora_adapter.target.__class__.__name__}"
|
||||
accum[f"{pfx}.down.weight"] = lora.down.weight
|
||||
accum[f"{pfx}.up.weight"] = lora.up.weight
|
||||
prev_parent = parent
|
||||
|
||||
def get_lora_weights(self, name: str) -> dict[str, Tensor]:
|
||||
r: dict[str, Tensor] = {}
|
||||
self._get_lora_weights(self.unet, name, r)
|
||||
self._get_lora_weights(self.clip_text_encoder, name, r)
|
||||
return r
|
||||
|
||||
def add_loras_to_text_encoder(
|
||||
self,
|
||||
loras: dict[str, Lora[Any]],
|
||||
|
|
Loading…
Reference in a new issue