diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 6d4c6e8..8fe78cc 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -115,6 +115,7 @@ class SDLoraManager: include: list[str] | None = None, exclude: list[str] | None = None, preprocess: dict[str, str] | None = None, + debug_map: list[tuple[str, str]] | None = None, ) -> None: """Add multiple LoRAs to the U-Net. @@ -144,9 +145,9 @@ class SDLoraManager: for exc, v in preprocess.items(): ls = {k: v for k, v in loras_excluded.items() if exc in k} - auto_attach_loras(ls, self.unet, include=[v]) + auto_attach_loras(ls, self.unet, include=[v], debug_map=debug_map) - auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include) + auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include, debug_map=debug_map) def remove_loras(self, *names: str) -> None: """Remove multiple LoRAs from the target. diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py index 02fca11..2513bc4 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py @@ -66,26 +66,11 @@ def add_lcm_lora( debug_map=debug_map, ) - exclusions = { - "res": "ResidualBlock", - "downsample": "Downsample", - "upsample": "Upsample", - } - loras_excluded = {k: v for k, v in loras.items() if any(x in k for x in exclusions.keys())} - loras_remaining = {k: v for k, v in loras.items() if k not in loras_excluded and k not in loras_projs} - - auto_attach_loras( - loras_remaining, - unet, - exclude=[*exclusions.values(), "TimestepEncoder"], + manager.add_loras_to_unet( + {k: v for k, v in loras.items() if k not in loras_projs}, debug_map=debug_map, ) - # Process exclusions one by one to avoid mixing them up. - for exc, v in exclusions.items(): - ls = {k: v for k, v in loras_excluded.items() if exc in k} - auto_attach_loras(ls, unet, include=[v], debug_map=debug_map) - if debug_map is not None: _check_validity(debug_map)