simplify LCM weights loader using new manager features

This commit is contained in:
Pierre Chapuis 2024-03-05 18:53:17 +01:00
parent ccd9414ff1
commit 4259261f17
2 changed files with 5 additions and 19 deletions

View file

@ -115,6 +115,7 @@ class SDLoraManager:
include: list[str] | None = None, include: list[str] | None = None,
exclude: list[str] | None = None, exclude: list[str] | None = None,
preprocess: dict[str, str] | None = None, preprocess: dict[str, str] | None = None,
debug_map: list[tuple[str, str]] | None = None,
) -> None: ) -> None:
"""Add multiple LoRAs to the U-Net. """Add multiple LoRAs to the U-Net.
@ -144,9 +145,9 @@ class SDLoraManager:
for exc, v in preprocess.items(): for exc, v in preprocess.items():
ls = {k: v for k, v in loras_excluded.items() if exc in k} ls = {k: v for k, v in loras_excluded.items() if exc in k}
auto_attach_loras(ls, self.unet, include=[v]) auto_attach_loras(ls, self.unet, include=[v], debug_map=debug_map)
auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include) auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include, debug_map=debug_map)
def remove_loras(self, *names: str) -> None: def remove_loras(self, *names: str) -> None:
"""Remove multiple LoRAs from the target. """Remove multiple LoRAs from the target.

View file

@ -66,26 +66,11 @@ def add_lcm_lora(
debug_map=debug_map, debug_map=debug_map,
) )
exclusions = { manager.add_loras_to_unet(
"res": "ResidualBlock", {k: v for k, v in loras.items() if k not in loras_projs},
"downsample": "Downsample",
"upsample": "Upsample",
}
loras_excluded = {k: v for k, v in loras.items() if any(x in k for x in exclusions.keys())}
loras_remaining = {k: v for k, v in loras.items() if k not in loras_excluded and k not in loras_projs}
auto_attach_loras(
loras_remaining,
unet,
exclude=[*exclusions.values(), "TimestepEncoder"],
debug_map=debug_map, debug_map=debug_map,
) )
# Process exclusions one by one to avoid mixing them up.
for exc, v in exclusions.items():
ls = {k: v for k, v in loras_excluded.items() if exc in k}
auto_attach_loras(ls, unet, include=[v], debug_map=debug_map)
if debug_map is not None: if debug_map is not None:
_check_validity(debug_map) _check_validity(debug_map)