mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
LoRA manager: move exclude / include to add_loras call
Always exclude the TimestepEncoder by default. This is because some keys include both e.g. `resnet` and `time_emb_proj`. Preprocess blocks that tend to mix up with others in a separate auto_attach call.
This commit is contained in:
parent
052a20b897
commit
8c7fcbc00f
|
@ -18,10 +18,6 @@ class SDLoraManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: LatentDiffusionModel,
|
target: LatentDiffusionModel,
|
||||||
unet_inclusions: list[str] | None = None,
|
|
||||||
unet_exclusions: list[str] | None = None,
|
|
||||||
text_encoder_inclusions: list[str] | None = None,
|
|
||||||
text_encoder_exclusions: list[str] | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the LoRA manager.
|
"""Initialize the LoRA manager.
|
||||||
|
|
||||||
|
@ -29,10 +25,6 @@ class SDLoraManager:
|
||||||
target: The target model to manage the LoRAs for.
|
target: The target model to manage the LoRAs for.
|
||||||
"""
|
"""
|
||||||
self.target = target
|
self.target = target
|
||||||
self.unet_inclusions = unet_inclusions
|
|
||||||
self.unet_exclusions = unet_exclusions
|
|
||||||
self.text_encoder_inclusions = text_encoder_inclusions
|
|
||||||
self.text_encoder_exclusions = text_encoder_exclusions
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unet(self) -> fl.Chain:
|
def unet(self) -> fl.Chain:
|
||||||
|
@ -54,6 +46,11 @@ class SDLoraManager:
|
||||||
/,
|
/,
|
||||||
tensors: dict[str, Tensor],
|
tensors: dict[str, Tensor],
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
|
unet_inclusions: list[str] | None = None,
|
||||||
|
unet_exclusions: list[str] | None = None,
|
||||||
|
unet_preprocess: dict[str, str] | None = None,
|
||||||
|
text_encoder_inclusions: list[str] | None = None,
|
||||||
|
text_encoder_exclusions: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Load a single LoRA from a `state_dict`.
|
"""Load a single LoRA from a `state_dict`.
|
||||||
|
|
||||||
|
@ -89,13 +86,19 @@ class SDLoraManager:
|
||||||
loras = {f"unet_{key}": value for key, value in loras.items()}
|
loras = {f"unet_{key}": value for key, value in loras.items()}
|
||||||
|
|
||||||
# attach the LoRA to the target
|
# attach the LoRA to the target
|
||||||
self.add_loras_to_unet(loras)
|
self.add_loras_to_unet(loras, include=unet_inclusions, exclude=unet_exclusions, preprocess=unet_preprocess)
|
||||||
self.add_loras_to_text_encoder(loras)
|
self.add_loras_to_text_encoder(loras, include=text_encoder_inclusions, exclude=text_encoder_exclusions)
|
||||||
|
|
||||||
# set the scale of the LoRA
|
# set the scale of the LoRA
|
||||||
self.set_scale(name, scale)
|
self.set_scale(name, scale)
|
||||||
|
|
||||||
def add_loras_to_text_encoder(self, loras: dict[str, Lora[Any]], /) -> None:
|
def add_loras_to_text_encoder(
|
||||||
|
self,
|
||||||
|
loras: dict[str, Lora[Any]],
|
||||||
|
/,
|
||||||
|
include: list[str] | None = None,
|
||||||
|
exclude: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
"""Add multiple LoRAs to the text encoder.
|
"""Add multiple LoRAs to the text encoder.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -103,14 +106,16 @@ class SDLoraManager:
|
||||||
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
(keys are the names of the LoRAs, values are the LoRAs to add to the text encoder)
|
||||||
"""
|
"""
|
||||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||||
auto_attach_loras(
|
auto_attach_loras(text_encoder_loras, self.clip_text_encoder, exclude=exclude, include=include)
|
||||||
text_encoder_loras,
|
|
||||||
self.clip_text_encoder,
|
|
||||||
exclude=self.text_encoder_exclusions,
|
|
||||||
include=self.text_encoder_inclusions,
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_loras_to_unet(self, loras: dict[str, Lora[Any]], /) -> None:
|
def add_loras_to_unet(
|
||||||
|
self,
|
||||||
|
loras: dict[str, Lora[Any]],
|
||||||
|
/,
|
||||||
|
include: list[str] | None = None,
|
||||||
|
exclude: list[str] | None = None,
|
||||||
|
preprocess: dict[str, str] | None = None,
|
||||||
|
) -> None:
|
||||||
"""Add multiple LoRAs to the U-Net.
|
"""Add multiple LoRAs to the U-Net.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -119,20 +124,29 @@ class SDLoraManager:
|
||||||
"""
|
"""
|
||||||
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
||||||
|
|
||||||
if self.unet_exclusions is None:
|
if exclude is None:
|
||||||
auto_exclusions = {
|
exclude = ["TimestepEncoder"]
|
||||||
"time": "TimestepEncoder",
|
|
||||||
|
if preprocess is None:
|
||||||
|
preprocess = {
|
||||||
"res": "ResidualBlock",
|
"res": "ResidualBlock",
|
||||||
"downsample": "Downsample",
|
"downsample": "Downsample",
|
||||||
"upsample": "Upsample",
|
"upsample": "Upsample",
|
||||||
}
|
}
|
||||||
exclusions = [
|
|
||||||
block for s, block in auto_exclusions.items() if all([s not in key for key in unet_loras.keys()])
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
exclusions = self.unet_exclusions
|
|
||||||
|
|
||||||
auto_attach_loras(unet_loras, self.unet, exclude=exclusions, include=self.unet_inclusions)
|
if include is not None:
|
||||||
|
preprocess = {k: v for k, v in preprocess.items() if v in include}
|
||||||
|
|
||||||
|
preprocess = {k: v for k, v in preprocess.items() if v not in exclude}
|
||||||
|
|
||||||
|
loras_excluded = {k: v for k, v in unet_loras.items() if any(x in k for x in preprocess.keys())}
|
||||||
|
loras_remaining = {k: v for k, v in unet_loras.items() if k not in loras_excluded}
|
||||||
|
|
||||||
|
for exc, v in preprocess.items():
|
||||||
|
ls = {k: v for k, v in loras_excluded.items() if exc in k}
|
||||||
|
auto_attach_loras(ls, self.unet, include=[v])
|
||||||
|
|
||||||
|
auto_attach_loras(loras_remaining, self.unet, exclude=exclude, include=include)
|
||||||
|
|
||||||
def remove_loras(self, *names: str) -> None:
|
def remove_loras(self, *names: str) -> None:
|
||||||
"""Remove multiple LoRAs from the target.
|
"""Remove multiple LoRAs from the target.
|
||||||
|
|
|
@ -66,7 +66,6 @@ def add_lcm_lora(
|
||||||
debug_map=debug_map,
|
debug_map=debug_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Do *not* check for time because some keys include both `resnets` and `time_emb_proj`.
|
|
||||||
exclusions = {
|
exclusions = {
|
||||||
"res": "ResidualBlock",
|
"res": "ResidualBlock",
|
||||||
"downsample": "Downsample",
|
"downsample": "Downsample",
|
||||||
|
|
|
@ -1394,8 +1394,7 @@ def test_diffusion_sdxl_lora(
|
||||||
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography"
|
||||||
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white"
|
||||||
|
|
||||||
manager = SDLoraManager(sdxl, unet_inclusions=["CrossAttentionBlock"])
|
SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale, unet_inclusions=["CrossAttentionBlock"])
|
||||||
manager.add_loras("dpo", lora_weights, scale=lora_scale)
|
|
||||||
|
|
||||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
text=prompt, negative_text=negative_prompt
|
text=prompt, negative_text=negative_prompt
|
||||||
|
|
Loading…
Reference in a new issue