mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
simplify LCM weights loader using new manager features
This commit is contained in:
parent
ccd9414ff1
commit
4259261f17
|
@ -115,6 +115,7 @@ class SDLoraManager:
|
|||
include: list[str] | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
preprocess: dict[str, str] | None = None,
|
||||
debug_map: list[tuple[str, str]] | None = None,
|
||||
) -> None:
|
||||
"""Add multiple LoRAs to the U-Net.
|
||||
|
||||
|
@ -144,9 +145,9 @@ class SDLoraManager:
|
|||
|
||||
for exc, v in preprocess.items():
|
||||
ls = {k: v for k, v in loras_excluded.items() if exc in k}
|
||||
auto_attach_loras(ls, self.unet, include=[v])
|
||||
auto_attach_loras(ls, self.unet, include=[v], debug_map=debug_map)
|
||||
|
||||
auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include)
|
||||
auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include, debug_map=debug_map)
|
||||
|
||||
def remove_loras(self, *names: str) -> None:
|
||||
"""Remove multiple LoRAs from the target.
|
||||
|
|
|
@ -66,26 +66,11 @@ def add_lcm_lora(
|
|||
debug_map=debug_map,
|
||||
)
|
||||
|
||||
exclusions = {
|
||||
"res": "ResidualBlock",
|
||||
"downsample": "Downsample",
|
||||
"upsample": "Upsample",
|
||||
}
|
||||
loras_excluded = {k: v for k, v in loras.items() if any(x in k for x in exclusions.keys())}
|
||||
loras_remaining = {k: v for k, v in loras.items() if k not in loras_excluded and k not in loras_projs}
|
||||
|
||||
auto_attach_loras(
|
||||
loras_remaining,
|
||||
unet,
|
||||
exclude=[*exclusions.values(), "TimestepEncoder"],
|
||||
manager.add_loras_to_unet(
|
||||
{k: v for k, v in loras.items() if k not in loras_projs},
|
||||
debug_map=debug_map,
|
||||
)
|
||||
|
||||
# Process exclusions one by one to avoid mixing them up.
|
||||
for exc, v in exclusions.items():
|
||||
ls = {k: v for k, v in loras_excluded.items() if exc in k}
|
||||
auto_attach_loras(ls, unet, include=[v], debug_map=debug_map)
|
||||
|
||||
if debug_map is not None:
|
||||
_check_validity(debug_map)
|
||||
|
||||
|
|
Loading…
Reference in a new issue