From 8c7fcbc00fd22d5bd1bc50d8af6a550e3ce99589 Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Tue, 5 Mar 2024 16:51:02 +0100
Subject: [PATCH] LoRA manager: move exclude / include to add_loras call

Always exclude the TimestepEncoder by default: matching on `time` is
unreliable because some keys contain both `resnet` and `time_emb_proj`.
Blocks that tend to get mixed up with others (ResidualBlock, Downsample,
Upsample) are now preprocessed in separate auto_attach_loras calls.
---
 .../foundationals/latent_diffusion/lora.py    | 68 +++++++++++--------
 .../stable_diffusion_xl/lcm_lora.py           |  1 -
 tests/e2e/test_diffusion.py                   |  3 +-
 3 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index 94e231f..6d4c6e8 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -18,10 +18,6 @@ class SDLoraManager:
     def __init__(
         self,
         target: LatentDiffusionModel,
-        unet_inclusions: list[str] | None = None,
-        unet_exclusions: list[str] | None = None,
-        text_encoder_inclusions: list[str] | None = None,
-        text_encoder_exclusions: list[str] | None = None,
     ) -> None:
         """Initialize the LoRA manager.
 
@@ -29,10 +25,6 @@ class SDLoraManager:
             target: The target model to manage the LoRAs for.
         """
         self.target = target
-        self.unet_inclusions = unet_inclusions
-        self.unet_exclusions = unet_exclusions
-        self.text_encoder_inclusions = text_encoder_inclusions
-        self.text_encoder_exclusions = text_encoder_exclusions
 
     @property
     def unet(self) -> fl.Chain:
@@ -54,6 +46,11 @@ class SDLoraManager:
         /,
         tensors: dict[str, Tensor],
         scale: float = 1.0,
+        unet_inclusions: list[str] | None = None,
+        unet_exclusions: list[str] | None = None,
+        unet_preprocess: dict[str, str] | None = None,
+        text_encoder_inclusions: list[str] | None = None,
+        text_encoder_exclusions: list[str] | None = None,
     ) -> None:
         """Load a single LoRA from a `state_dict`.
 
@@ -89,13 +86,19 @@ class SDLoraManager:
         loras = {f"unet_{key}": value for key, value in loras.items()}
 
         # attach the LoRA to the target
-        self.add_loras_to_unet(loras)
-        self.add_loras_to_text_encoder(loras)
+        self.add_loras_to_unet(loras, include=unet_inclusions, exclude=unet_exclusions, preprocess=unet_preprocess)
+        self.add_loras_to_text_encoder(loras, include=text_encoder_inclusions, exclude=text_encoder_exclusions)
 
         # set the scale of the LoRA
         self.set_scale(name, scale)
 
-    def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
+    def add_loras_to_text_encoder(
+        self,
+        loras: dict[str, Lora[Any]],
+        /,
+        include: list[str] | None = None,
+        exclude: list[str] | None = None,
+    ) -> None:
         """Add multiple LoRAs to the text encoder.
 
         Args:
@@ -103,14 +106,16 @@ class SDLoraManager:
             (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
         """
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
-        auto_attach_loras(
-            text_encoder_loras,
-            self.clip_text_encoder,
-            exclude=self.text_encoder_exclusions,
-            include=self.text_encoder_inclusions,
-        )
+        auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
 
-    def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
+    def add_loras_to_unet(
+        self,
+        loras: dict[str, Lora[Any]],
+        /,
+        include: list[str] | None = None,
+        exclude: list[str] | None = None,
+        preprocess: dict[str, str] | None = None,
+    ) -> None:
         """Add multiple LoRAs to the U-Net.
 
         Args:
@@ -119,20 +124,29 @@ class SDLoraManager:
             (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)
         """
         unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
-        if self.unet_exclusions is None:
-            auto_exclusions = {
-                "time": "TimestepEncoder",
+        if exclude is None:
+            exclude = ["TimestepEncoder"]
+
+        if preprocess is None:
+            preprocess = {
                 "res": "ResidualBlock",
                 "downsample": "Downsample",
                 "upsample": "Upsample",
             }
-            exclusions = [
-                block for s, block in auto_exclusions.items() if all([s not in key for key in unet_loras.keys()])
-            ]
-        else:
-            exclusions = self.unet_exclusions
 
-        auto_attach_loras(unet_loras, self.unet, exclude=exclusions, include=self.unet_inclusions)
+        if include is not None:
+            preprocess = {k: v for k, v in preprocess.items() if v in include}
+
+        preprocess = {k: v for k, v in preprocess.items() if v not in exclude}
+
+        loras_to_preprocess = {k: v for k, v in unet_loras.items() if any(x in k for x in preprocess.keys())}
+        loras_remaining = {k: v for k, v in unet_loras.items() if k not in loras_to_preprocess}
+
+        for prefix, block in preprocess.items():
+            ls = {k: v for k, v in loras_to_preprocess.items() if prefix in k}
+            auto_attach_loras(ls, self.unet, include=[block])
+
+        auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include)
 
     def remove_loras(self, *names: str) -> None:
         """Remove multiple LoRAs from the target.
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py
index a0b1c22..02fca11 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/lcm_lora.py
@@ -66,7 +66,6 @@ def add_lcm_lora(
         debug_map=debug_map,
     )
 
-    # Do *not* check for time because some keys include both `resnets` and `time_emb_proj`.
     exclusions = {
         "res": "ResidualBlock",
         "downsample": "Downsample",
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index 69465e0..c98060c 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -1394,8 +1394,7 @@ def test_diffusion_sdxl_lora(
     prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
     negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
 
-    manager = SDLoraManager(sdxl, unet_inclusions=["CrossAttentionBlock"])
-    manager.add_loras("dpo", lora_weights, scale=lora_scale)
+    SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale, unet_inclusions=["CrossAttentionBlock"])
 
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text=prompt, negative_text=negative_prompt
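
Usage sketch of the call-site migration, assuming the `sdxl` model,
`lora_weights` state dict and `lora_scale` value from
tests/e2e/test_diffusion.py; the filter values are illustrative:

    from refiners.foundationals.latent_diffusion.lora import SDLoraManager

    # Before this patch, filters were fixed once, at construction time:
    #   manager = SDLoraManager(sdxl, unet_inclusions=["CrossAttentionBlock"])
    #   manager.add_loras("dpo", lora_weights, scale=lora_scale)

    # After this patch, the manager takes only the target model, and the
    # filters travel with each add_loras call, so different LoRAs can use
    # different inclusions / exclusions on the same manager:
    manager = SDLoraManager(sdxl)
    manager.add_loras(
        "dpo",
        lora_weights,
        scale=lora_scale,
        unet_inclusions=["CrossAttentionBlock"],
    )

When unet_exclusions is left as None, add_loras_to_unet now always
excludes the TimestepEncoder, and ResidualBlock / Downsample / Upsample
LoRAs are attached through the separate preprocess pass introduced by
this patch.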