mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
allow passing inclusions and exlusions to SDLoraManager
This commit is contained in:
parent
cce2a98fa6
commit
ed8ec26e63
|
@ -18,6 +18,10 @@ class SDLoraManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: LatentDiffusionModel,
|
target: LatentDiffusionModel,
|
||||||
|
unet_inclusions: list[str] | None = None,
|
||||||
|
unet_exclusions: list[str] | None = None,
|
||||||
|
text_encoder_inclusions: list[str] | None = None,
|
||||||
|
text_encoder_exclusions: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the LoRA manager.
|
"""Initialize the LoRA manager.
|
||||||
|
|
||||||
|
@ -25,6 +29,10 @@ class SDLoraManager:
|
||||||
target: The target model to manage the LoRAs for.
|
target: The target model to manage the LoRAs for.
|
||||||
"""
|
"""
|
||||||
self.target = target
|
self.target = target
|
||||||
|
self.unet_inclusions = unet_inclusions
|
||||||
|
self.unet_exclusions = unet_exclusions
|
||||||
|
self.text_encoder_inclusions = text_encoder_inclusions
|
||||||
|
self.text_encoder_exclusions = text_encoder_exclusions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unet(self) -> fl.Chain:
|
def unet(self) -> fl.Chain:
|
||||||
|
@ -114,7 +122,12 @@ class SDLoraManager:
|
||||||
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
||||||
"""
|
"""
|
||||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||||
auto_attach_loras(text_encoder_loras, self.clip_text_encoder)
|
auto_attach_loras(
|
||||||
|
text_encoder_loras,
|
||||||
|
self.clip_text_encoder,
|
||||||
|
exclude=self.text_encoder_exclusions,
|
||||||
|
include=self.text_encoder_inclusions,
|
||||||
|
)
|
||||||
|
|
||||||
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
||||||
"""Add multiple LoRAs to the U-Net.
|
"""Add multiple LoRAs to the U-Net.
|
||||||
|
@ -124,10 +137,21 @@ class SDLoraManager:
|
||||||
(keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)
|
(keys are the names of the LoRAs, values are the LoRAs to add to the U-Net)
|
||||||
"""
|
"""
|
||||||
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
||||||
exclude = [
|
|
||||||
block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
|
if self.unet_exclusions is None:
|
||||||
]
|
auto_exclusions = {
|
||||||
auto_attach_loras(unet_loras, self.unet, exclude=exclude)
|
"time": "TimestepEncoder",
|
||||||
|
"res": "ResidualBlock",
|
||||||
|
"downsample": "Downsample",
|
||||||
|
"upsample": "Upsample",
|
||||||
|
}
|
||||||
|
exclusions = [
|
||||||
|
block for s, block in auto_exclusions.items() if all([s not in key for key in unet_loras.keys()])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
exclusions = self.unet_exclusions
|
||||||
|
|
||||||
|
auto_attach_loras(unet_loras, self.unet, exclude=exclusions, include=self.unet_inclusions)
|
||||||
|
|
||||||
def remove_loras(self, *names: str) -> None:
|
def remove_loras(self, *names: str) -> None:
|
||||||
"""Remove multiple LoRAs from the target.
|
"""Remove multiple LoRAs from the target.
|
||||||
|
@ -206,15 +230,6 @@ class SDLoraManager:
|
||||||
"""List of all the LoraAdapters managed by the SDLoraManager."""
|
"""List of all the LoraAdapters managed by the SDLoraManager."""
|
||||||
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
||||||
|
|
||||||
@property
|
|
||||||
def unet_exclusions(self) -> dict[str, str]:
|
|
||||||
return {
|
|
||||||
"time": "TimestepEncoder",
|
|
||||||
"res": "ResidualBlock",
|
|
||||||
"downsample": "Downsample",
|
|
||||||
"upsample": "Upsample",
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scales(self) -> dict[str, float]:
|
def scales(self) -> dict[str, float]:
|
||||||
"""The scales of all the LoRAs managed by the SDLoraManager."""
|
"""The scales of all the LoRAs managed by the SDLoraManager."""
|
||||||
|
|
Loading…
Reference in a new issue