diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 0cfefd9..e70190c 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -92,6 +92,25 @@ class SDLoraManager: # set the scale of the LoRA self.set_scale(name, scale) + def _get_lora_weights(self, base: fl.Chain, name: str, accum: dict[str, Tensor]) -> None: + prev_parent: fl.Chain | None = None + n = 0 + for lora_adapter, parent in base.walk(LoraAdapter): + lora = next((l for l in lora_adapter.lora_layers if l.name == name), None) + if lora is None: + continue + n = (parent == prev_parent) and n + 1 or 1 + pfx = f"{parent.get_path()}.{n}.{lora_adapter.target.__class__.__name__}" + accum[f"{pfx}.down.weight"] = lora.down.weight + accum[f"{pfx}.up.weight"] = lora.up.weight + prev_parent = parent + + def get_lora_weights(self, name: str) -> dict[str, Tensor]: + r: dict[str, Tensor] = {} + self._get_lora_weights(self.unet, name, r) + self._get_lora_weights(self.clip_text_encoder, name, r) + return r + def add_loras_to_text_encoder( self, loras: dict[str, Lora[Any]],