LoRA manager: move exclude / include to add_loras call

Always exclude the TimestepEncoder by default, because some keys
contain both e.g. `resnet` and `time_emb_proj`.

Preprocess blocks that tend to get mixed up with others in a separate
auto_attach call.
Pierre Chapuis 2024-03-05 16:51:02 +01:00
parent 052a20b897
commit 8c7fcbc00f
3 changed files with 42 additions and 30 deletions
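
For context, here is a minimal before/after sketch of a call site, reusing the
names from the test change below (`sdxl`, `lora_weights` and `lora_scale` are
fixtures assumed to be defined as in that test):

    # Before: inclusion / exclusion filters were fixed at construction time.
    manager = SDLoraManager(sdxl, unet_inclusions=["CrossAttentionBlock"])
    manager.add_loras("dpo", lora_weights, scale=lora_scale)

    # After: filters are regular keyword arguments of add_loras, so they
    # can vary per loaded LoRA instead of being fixed per manager.
    SDLoraManager(sdxl).add_loras(
        "dpo",
        lora_weights,
        scale=lora_scale,
        unet_inclusions=["CrossAttentionBlock"],
    )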


@@ -18,10 +18,6 @@ class SDLoraManager:
     def __init__(
         self,
         target: LatentDiffusionModel,
-        unet_inclusions: list[str] | None = None,
-        unet_exclusions: list[str] | None = None,
-        text_encoder_inclusions: list[str] | None = None,
-        text_encoder_exclusions: list[str] | None = None,
     ) -> None:
         """Initialize the LoRA manager.
 
@@ -29,10 +25,6 @@ class SDLoraManager:
             target: The target model to manage the LoRAs for.
         """
         self.target = target
-        self.unet_inclusions = unet_inclusions
-        self.unet_exclusions = unet_exclusions
-        self.text_encoder_inclusions = text_encoder_inclusions
-        self.text_encoder_exclusions = text_encoder_exclusions
 
     @property
     def unet(self) -> fl.Chain:
@@ -54,6 +46,11 @@ class SDLoraManager:
         /,
         tensors: dict[str, Tensor],
         scale: float = 1.0,
+        unet_inclusions: list[str] | None = None,
+        unet_exclusions: list[str] | None = None,
+        unet_preprocess: dict[str, str] | None = None,
+        text_encoder_inclusions: list[str] | None = None,
+        text_encoder_exclusions: list[str] | None = None,
     ) -> None:
         """Load a single LoRA from a `state_dict`.
 
@@ -89,13 +86,19 @@ class SDLoraManager:
             loras = {f"unet_{key}": value for key, value in loras.items()}
 
         # attach the LoRA to the target
-        self.add_loras_to_unet(loras)
-        self.add_loras_to_text_encoder(loras)
+        self.add_loras_to_unet(loras, include=unet_inclusions, exclude=unet_exclusions, preprocess=unet_preprocess)
+        self.add_loras_to_text_encoder(loras, include=text_encoder_inclusions, exclude=text_encoder_exclusions)
 
         # set the scale of the LoRA
         self.set_scale(name, scale)
 
-    def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
+    def add_loras_to_text_encoder(
+        self,
+        loras: dict[str, Lora[Any]],
+        /,
+        include: list[str] | None = None,
+        exclude: list[str] | None = None,
+    ) -> None:
         """Add multiple LoRAs to the text encoder.
 
         Args:
@@ -103,14 +106,16 @@ class SDLoraManager:
                 (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
         """
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
-        auto_attach_loras(
-            text_encoder_loras,
-            self.clip_text_encoder,
-            exclude=self.text_encoder_exclusions,
-            include=self.text_encoder_inclusions,
-        )
+        auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
 
-    def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
+    def add_loras_to_unet(
+        self,
+        loras: dict[str, Lora[Any]],
+        /,
+        include: list[str] | None = None,
+        exclude: list[str] | None = None,
+        preprocess: dict[str, str] | None = None,
+    ) -> None:
         """Add multiple LoRAs to the U-Net.
 
         Args:
@@ -119,20 +124,29 @@ class SDLoraManager:
         """
         unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
 
-        if self.unet_exclusions is None:
-            auto_exclusions = {
-                "time": "TimestepEncoder",
+        if exclude is None:
+            exclude = ["TimestepEncoder"]
+
+        if preprocess is None:
+            preprocess = {
                 "res": "ResidualBlock",
                 "downsample": "Downsample",
                 "upsample": "Upsample",
             }
-            exclusions = [
-                block for s, block in auto_exclusions.items() if all([s not in key for key in unet_loras.keys()])
-            ]
-        else:
-            exclusions = self.unet_exclusions
 
-        auto_attach_loras(unet_loras, self.unet, exclude=exclusions, include=self.unet_inclusions)
+        if include is not None:
+            preprocess = {k: v for k, v in preprocess.items() if v in include}
+
+        preprocess = {k: v for k, v in preprocess.items() if v not in exclude}
+
+        loras_excluded = {k: v for k, v in unet_loras.items() if any(x in k for x in preprocess.keys())}
+        loras_remaining = {k: v for k, v in unet_loras.items() if k not in loras_excluded}
+
+        for exc, v in preprocess.items():
+            ls = {k: v for k, v in loras_excluded.items() if exc in k}
+            auto_attach_loras(ls, self.unet, include=[v])
+
+        auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include)
 
     def remove_loras(self, *names: str) -> None:
         """Remove multiple LoRAs from the target.


@@ -66,7 +66,6 @@ def add_lcm_lora(
         debug_map=debug_map,
     )
 
-    # Do *not* check for time because some keys include both `resnets` and `time_emb_proj`.
     exclusions = {
         "res": "ResidualBlock",
         "downsample": "Downsample",


@@ -1394,8 +1394,7 @@ def test_diffusion_sdxl_lora(
     prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
     negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
 
-    manager = SDLoraManager(sdxl, unet_inclusions=["CrossAttentionBlock"])
-    manager.add_loras("dpo", lora_weights, scale=lora_scale)
+    SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale, unet_inclusions=["CrossAttentionBlock"])
 
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt, negative_text=negative_prompt