allow passing inclusions and exclusions to SDLoraManager

This commit is contained in:
Pierre Chapuis 2024-03-05 15:14:21 +01:00
parent cce2a98fa6
commit ed8ec26e63

View file

@@ -18,6 +18,10 @@ class SDLoraManager:
def __init__( def __init__(
self, self,
target: LatentDiffusionModel, target: LatentDiffusionModel,
unet_inclusions: list[str] | None = None,
unet_exclusions: list[str] | None = None,
text_encoder_inclusions: list[str] | None = None,
text_encoder_exclusions: list[str] | None = None,
) -> None: ) -> None:
"""Initialize the LoRA manager. """Initialize the LoRA manager.
@@ -25,6 +29,10 @@ class SDLoraManager:
target: The target model to manage the LoRAs for. target: The target model to manage the LoRAs for.
""" """
self.target = target self.target = target
self.unet_inclusions = unet_inclusions
self.unet_exclusions = unet_exclusions
self.text_encoder_inclusions = text_encoder_inclusions
self.text_encoder_exclusions = text_encoder_exclusions
@property @property
def unet(self) -> fl.Chain: def unet(self) -> fl.Chain:
@@ -114,7 +122,12 @@ class SDLoraManager:
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder) (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
""" """
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
auto_attach_loras(text_encoder_loras, self.clip_text_encoder) auto_attach_loras(
text_encoder_loras,
self.clip_text_encoder,
exclude=self.text_encoder_exclusions,
include=self.text_encoder_inclusions,
)
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None: def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
"""Add multiple LoRAs to the U-Net. """Add multiple LoRAs to the U-Net.
@@ -124,10 +137,21 @@ class SDLoraManager:
(keys are the names of the LoRAs, values are the LoRAs to add to the U-Net) (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)
""" """
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key} unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
exclude = [
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()]) if self.unet_exclusions is None:
auto_exclusions = {
"time": "TimestepEncoder",
"res": "ResidualBlock",
"downsample": "Downsample",
"upsample": "Upsample",
}
exclusions = [
block for s, block in auto_exclusions.items() if all([s not in key for key in unet_loras.keys()])
] ]
auto_attach_loras(unet_loras, self.unet, exclude=exclude) else:
exclusions = self.unet_exclusions
auto_attach_loras(unet_loras, self.unet, exclude=exclusions, include=self.unet_inclusions)
def remove_loras(self, *names: str) -> None: def remove_loras(self, *names: str) -> None:
"""Remove multiple LoRAs from the target. """Remove multiple LoRAs from the target.
@@ -206,15 +230,6 @@ class SDLoraManager:
"""List of all the LoraAdapters managed by the SDLoraManager.""" """List of all the LoraAdapters managed by the SDLoraManager."""
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter)) return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@property
def unet_exclusions(self) -> dict[str, str]:
return {
"time": "TimestepEncoder",
"res": "ResidualBlock",
"downsample": "Downsample",
"upsample": "Upsample",
}
@property @property
def scales(self) -> dict[str, float]: def scales(self) -> dict[str, float]:
"""The scales of all the LoRAs managed by the SDLoraManager.""" """The scales of all the LoRAs managed by the SDLoraManager."""