Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
(doc/foundationals) add SDLoraManager, related docstrings
parent 7406d8e01f
commit 6b35f1cc84
@@ -5,3 +5,5 @@
 ::: refiners.foundationals.latent_diffusion.stable_diffusion_1
 
 ::: refiners.foundationals.latent_diffusion.solvers
+
+::: refiners.foundationals.latent_diffusion.lora
@@ -8,20 +8,34 @@ from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
 
 
 class SDLoraManager:
+    """Manage LoRAs for a Stable Diffusion model.
+
+    Note:
+        In the context of SDLoraManager, a "LoRA" is a set of ["LoRA layers"][refiners.fluxion.adapters.lora.Lora]
+        that can be attached to a target model.
+    """
+
     def __init__(
         self,
         target: LatentDiffusionModel,
     ) -> None:
+        """Initialize the LoRA manager.
+
+        Args:
+            target: The target model to manage the LoRAs for.
+        """
         self.target = target
 
     @property
     def unet(self) -> fl.Chain:
+        """The Stable Diffusion's U-Net model."""
         unet = self.target.unet
         assert isinstance(unet, fl.Chain)
         return unet
 
     @property
     def clip_text_encoder(self) -> fl.Chain:
+        """The Stable Diffusion's text encoder."""
         clip_text_encoder = self.target.clip_text_encoder
         assert isinstance(clip_text_encoder, fl.Chain)
         return clip_text_encoder
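Usage note (illustrative, not part of the diff): a minimal sketch of constructing the manager documented above, assuming a Stable Diffusion 1.5 setup. `StableDiffusion_1` is assumed importable from `refiners.foundationals.latent_diffusion`, and the device/dtype choices are only examples.

import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import SDLoraManager

# build the target latent diffusion model, then hand it to the manager
sd15 = StableDiffusion_1(device="cuda", dtype=torch.float16)
manager = SDLoraManager(sd15)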
@@ -33,23 +47,44 @@ class SDLoraManager:
         tensors: dict[str, Tensor],
         scale: float = 1.0,
     ) -> None:
-        """Load the LoRA weights from a dictionary of tensors.
+        """Load a single LoRA from a `state_dict`.
 
-        Expects the keys to be in the commonly found formats on CivitAI's hub.
+        Warning:
+            This method expects the keys of the `state_dict` to be in the commonly found formats on CivitAI's hub.
+
+        Args:
+            name: The name of the LoRA.
+            tensors: The `state_dict` of the LoRA to load.
+            scale: The scale to use for the LoRA.
+
+        Raises:
+            AssertionError: If the manager already has a LoRA with the same name.
         """
         assert name not in self.names, f"LoRA {name} already exists"
+
+        # load the LoRA from the state_dict
         loras = Lora.from_dict(
-            name, {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
+            name,
+            state_dict={
+                key: value.to(
+                    device=self.target.device,
+                    dtype=self.target.dtype,
+                )
+                for key, value in tensors.items()
+            },
         )
+        # sort all the LoRA's keys using the `sort_keys` method
         loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
 
         # if no key contains "unet" or "text", assume all keys are for the unet
         if all("unet" not in key and "text" not in key for key in loras.keys()):
             loras = {f"unet_{key}": value for key, value in loras.items()}
 
+        # attach the LoRA to the target
         self.add_loras_to_unet(loras)
         self.add_loras_to_text_encoder(loras)
 
+        # set the scale of the LoRA
         self.set_scale(name, scale)
 
     def add_multiple_loras(
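Usage note (illustrative, not part of the diff): continuing the sketch above, `add_loras` takes a CivitAI-style `state_dict`. `load_from_safetensors` from `refiners.fluxion.utils` is one way to obtain it; the file path and LoRA name below are placeholders.

from refiners.fluxion.utils import load_from_safetensors

# keys are expected in the formats commonly found on CivitAI's hub
tensors = load_from_safetensors("pokemon_lora.safetensors")  # placeholder path
manager.add_loras("pokemon", tensors=tensors, scale=1.0)
# calling add_loras again with the name "pokemon" would raise an AssertionError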
@@ -58,14 +93,36 @@ class SDLoraManager:
         tensors: dict[str, dict[str, Tensor]],
         scale: dict[str, float] | None = None,
     ) -> None:
+        """Load multiple LoRAs from a dictionary of `state_dict`.
+
+        Args:
+            tensors: The dictionary of `state_dict` of the LoRAs to load
+                (keys are the names of the LoRAs, values are the `state_dict` of the LoRAs).
+            scale: The scales to use for the LoRAs.
+
+        Raises:
+            AssertionError: If the manager already has a LoRA with the same name.
+        """
         for name, lora_tensors in tensors.items():
             self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0)
 
     def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None:
+        """Add multiple LoRAs to the text encoder.
+
+        Args:
+            loras: The dictionary of LoRAs to add to the text encoder
+                (keys are the names of the LoRAs, values are the LoRAs to add to the text encoder).
+        """
         text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
         SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
 
     def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None:
+        """Add multiple LoRAs to the U-Net.
+
+        Args:
+            loras: The dictionary of LoRAs to add to the U-Net
+                (keys are the names of the LoRAs, values are the LoRAs to add to the U-Net).
+        """
         unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
         exclude = [
             block for s, block in self.unet_exclusions.items() if all([s not in key for key in unet_loras.keys()])
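Usage note (illustrative, not part of the diff): `add_multiple_loras` simply loops over `add_loras`. The names and state dicts below are placeholders; omitting `scale` defaults every LoRA to 1.0.

manager.add_multiple_loras(
    tensors={
        "pokemon": pokemon_tensors,      # each value is a LoRA state_dict
        "pixel_art": pixel_art_tensors,
    },
    scale={"pokemon": 1.0, "pixel_art": 0.55},
)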
@@ -73,6 +130,11 @@ class SDLoraManager:
         SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
 
     def remove_loras(self, *names: str) -> None:
+        """Remove multiple LoRAs from the target.
+
+        Args:
+            names: The names of the LoRAs to remove.
+        """
         for lora_adapter in self.lora_adapters:
             for name in names:
                 lora_adapter.remove_lora(name)
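Usage note (illustrative, not part of the diff): removal is by the names given at load time, e.g. with the placeholder names used above.

# detach the LoRA layers registered under these names
manager.remove_loras("pokemon", "pixel_art")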
@@ -81,21 +143,47 @@ class SDLoraManager:
             lora_adapter.eject()
 
     def remove_all(self) -> None:
+        """Remove all the LoRAs from the target."""
         for lora_adapter in self.lora_adapters:
             lora_adapter.eject()
 
     def get_loras_by_name(self, name: str, /) -> list[Lora]:
+        """Get the LoRA layers with the given name.
+
+        Args:
+            name: The name of the LoRA.
+        """
         return [lora for lora in self.loras if lora.name == name]
 
     def get_scale(self, name: str, /) -> float:
+        """Get the scale of the LoRA with the given name.
+
+        Args:
+            name: The name of the LoRA.
+
+        Returns:
+            The scale of the LoRA layers with the given name.
+        """
         loras = self.get_loras_by_name(name)
         assert all([lora.scale == loras[0].scale for lora in loras]), "lora scales are not all the same"
         return loras[0].scale
 
     def set_scale(self, name: str, scale: float, /) -> None:
+        """Set the scale of the LoRA with the given name.
+
+        Args:
+            name: The name of the LoRA.
+            scale: The new scale to set.
+        """
         self.update_scales({name: scale})
 
     def update_scales(self, scales: dict[str, float], /) -> None:
+        """Update the scales of multiple LoRAs.
+
+        Args:
+            scales: The scales to update
+                (keys are the names of the LoRAs, values are the new scales to set).
+        """
         assert all([name in self.names for name in scales]), f"Scales keys must be a subset of {self.names}"
         for name, scale in scales.items():
             for lora in self.get_loras_by_name(name):
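Usage note (illustrative, not part of the diff): a sketch of the scale accessors documented above, with placeholder names; `get_scale` asserts that all layers of a given LoRA share the same scale.

manager.get_scale("pokemon")                    # e.g. 1.0
manager.set_scale("pokemon", 0.75)              # one LoRA at a time
manager.update_scales({"pokemon": 0.75, "pixel_art": 0.3})  # several at once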
@@ -103,14 +191,17 @@ class SDLoraManager:
 
     @property
     def loras(self) -> list[Lora]:
+        """List of all the LoRA layers managed by the SDLoraManager."""
         return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
 
     @property
     def names(self) -> list[str]:
+        """List of all the LoRA names managed by the SDLoraManager."""
         return list(set(lora.name for lora in self.loras))
 
     @property
     def lora_adapters(self) -> list[LoraAdapter]:
+        """List of all the LoraAdapters managed by the SDLoraManager."""
         return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
 
     @property
@@ -124,6 +215,7 @@ class SDLoraManager:
 
     @property
     def scales(self) -> dict[str, float]:
+        """The scales of all the LoRAs managed by the SDLoraManager."""
         return {name: self.get_scale(name) for name in self.names}
 
     @staticmethod
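Usage note (illustrative, not part of the diff): the properties documented above make the manager easy to inspect; names continue the placeholders used earlier.

manager.names           # e.g. ["pokemon", "pixel_art"]
manager.scales          # e.g. {"pokemon": 0.75, "pixel_art": 0.3}
manager.loras           # every Lora layer attached to the U-Net or text encoder
manager.lora_adapters   # every LoraAdapter attached to the target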