From ed8ec26e633de6915ae3f3f865cbe19fef227f37 Mon Sep 17 00:00:00 2001
From: Pierre Chapuis
Date: Tue, 5 Mar 2024 15:14:21 +0100
Subject: [PATCH] allow passing inclusions and exclusions to SDLoraManager

---
 .../foundationals/latent_diffusion/lora.py | 43 +++++++++++++------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index 822a0c4..862b1dc 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -18,6 +18,10 @@ class SDLoraManager:
     def __init__(
         self,
         target: LatentDiffusionModel,
+        unet_inclusions: list[str] | None = None,
+        unet_exclusions: list[str] | None = None,
+        text_encoder_inclusions: list[str] | None = None,
+        text_encoder_exclusions: list[str] | None = None,
     ) -> None:
         """Initialize the LoRA manager.
 
@@ -25,6 +29,10 @@ class SDLoraManager:
             target: The target model to manage the LoRAs for.
         """
         self.target = target
+        self.unet_inclusions = unet_inclusions
+        self.unet_exclusions = unet_exclusions
+        self.text_encoder_inclusions = text_encoder_inclusions
+        self.text_encoder_exclusions = text_encoder_exclusions
 
     @property
     def unet(self) -> fl.Chain:
@@ -114,7 +122,12 @@ class SDLoraManager:
                 (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
         """
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
-        auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
+        auto_attach_loras(
+            text_encoder_loras,
+            self.clip_text_encoder,
+            exclude=self.text_encoder_exclusions,
+            include=self.text_encoder_inclusions,
+        )
 
     def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
         """Add multiple LoRAs to the U-Net.
@@ -124,10 +137,21 @@ class SDLoraManager:
                 (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)
         """
         unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
-        exclude = [
-            block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
-        ]
-        auto_attach_loras(unet_loras, self.unet, exclude=exclude)
+
+        if self.unet_exclusions is None:
+            auto_exclusions = {
+                "time": "TimestepEncoder",
+                "res": "ResidualBlock",
+                "downsample": "Downsample",
+                "upsample": "Upsample",
+            }
+            exclusions = [
+                block for s, block in auto_exclusions.items() if all([s not in key for key in unet_loras.keys()])
+            ]
+        else:
+            exclusions = self.unet_exclusions
+
+        auto_attach_loras(unet_loras, self.unet, exclude=exclusions, include=self.unet_inclusions)
 
     def remove_loras(self, *names: str) -> None:
         """Remove multiple LoRAs from the target.
@@ -206,15 +230,6 @@ class SDLoraManager:
         """List of all the LoraAdapters managed by the SDLoraManager."""
         return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
 
-    @property
-    def unet_exclusions(self) -> dict[str, str]:
-        return {
-            "time": "TimestepEncoder",
-            "res": "ResidualBlock",
-            "downsample": "Downsample",
-            "upsample": "Upsample",
-        }
-
     @property
     def scales(self) -> dict[str, float]:
         """The scales of all the LoRAs managed by the SDLoraManager."""
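
The None default for unet_exclusions keeps the previous behavior: a U-Net
block class is excluded automatically unless some U-Net LoRA key mentions its
shorthand. A self-contained sketch of that fallback filter, with the names
copied from the hunk above:

    # Mirrors the fallback filter in add_loras_to_unet.
    auto_exclusions = {
        "time": "TimestepEncoder",
        "res": "ResidualBlock",
        "downsample": "Downsample",
        "upsample": "Upsample",
    }

    def default_exclusions(unet_lora_keys: list[str]) -> list[str]:
        # Exclude a block class unless a LoRA key targets it explicitly.
        return [block for s, block in auto_exclusions.items() if all(s not in key for key in unet_lora_keys)]

    # A checkpoint with attention-only LoRAs keeps all four exclusions...
    assert default_exclusions(["unet_attn_q"]) == ["TimestepEncoder", "ResidualBlock", "Downsample", "Upsample"]
    # ...while one that also targets residual blocks stops excluding them.
    assert "ResidualBlock" not in default_exclusions(["unet_attn_q", "unet_res_in"])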
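
Not part of the commit: a minimal usage sketch of the new constructor
parameters, assuming a refiners StableDiffusion_1 target and the existing
add_loras / load_from_safetensors entry points. The filter strings
"CrossAttentionBlock" and "LayerNorm" are illustrative class names, not
values mandated by this patch.

    from refiners.fluxion.utils import load_from_safetensors
    from refiners.foundationals.latent_diffusion import StableDiffusion_1
    from refiners.foundationals.latent_diffusion.lora import SDLoraManager

    sd15 = StableDiffusion_1()

    # Only attach U-Net LoRA layers under cross-attention blocks, and skip
    # any text encoder LoRA that would land inside a LayerNorm.
    manager = SDLoraManager(
        sd15,
        unet_inclusions=["CrossAttentionBlock"],  # illustrative class name
        text_encoder_exclusions=["LayerNorm"],  # illustrative class name
    )
    manager.add_loras("my-lora", load_from_safetensors("my_lora.safetensors"))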