import random
from typing import Any

from loguru import logger
from pydantic import BaseModel
from torch import Tensor
from torch.utils.data import Dataset

import refiners.fluxion.layers as fl
from refiners.adapters.lora import Lora, LoraAdapter
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
from refiners.training_utils.callback import Callback
from refiners.training_utils.latent_diffusion import (
    FinetuneLatentDiffusionConfig,
    LatentDiffusionConfig,
    LatentDiffusionTrainer,
    TextEmbeddingLatentsBatch,
    TextEmbeddingLatentsDataset,
)


class LoraConfig(BaseModel):
    rank: int = 32
    trigger_phrase: str = ""
    use_only_trigger_probability: float = 0.0
    unet_targets: list[LoraTarget]
    text_encoder_targets: list[LoraTarget]
    lda_targets: list[LoraTarget]

    def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
        for linear, parent in lora_targets(module, target):
            adapter = LoraAdapter(target=linear, rank=self.rank)
            adapter.inject(parent)
            for linear in adapter.Lora.layers(fl.Linear):
                linear.requires_grad_(requires_grad=True)
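

# Sketch of the `[lora]` table in the TOML config this script consumes. Only the keys
# defined by `LoraConfig` above are shown; the remaining tables come from
# `FinetuneLatentDiffusionConfig` / `LatentDiffusionConfig`, and the target strings must
# be names of `LoraTarget` members (the values below are placeholders):
#
#   [lora]
#   rank = 32
#   trigger_phrase = "a photo of <token>"
#   use_only_trigger_probability = 0.0
#   unet_targets = ["<LoraTarget member>"]
#   text_encoder_targets = ["<LoraTarget member>"]
#   lda_targets = []  # an empty list skips that model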


class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
    def __init__(
        self,
        trainer: "LoraLatentDiffusionTrainer",
    ) -> None:
        super().__init__(trainer=trainer)
        self.trigger_phrase = trainer.config.lora.trigger_phrase
        self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
        logger.info(f"Trigger phrase: {self.trigger_phrase}")

    def process_caption(self, caption: str) -> str:
        caption = super().process_caption(caption=caption)
        if self.trigger_phrase:
            # With probability `use_only_trigger_probability` the caption is replaced by the
            # trigger phrase alone; otherwise the trigger phrase is prepended to the caption.
            caption = (
                f"{self.trigger_phrase} {caption}"
                if random.random() > self.use_only_trigger_probability
                else self.trigger_phrase
            )
        return caption


class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
    latent_diffusion: LatentDiffusionConfig
    lora: LoraConfig

    def model_post_init(self, __context: Any) -> None:
        """Pydantic v2 does post init differently, so we need to override this method too."""
        logger.info("Freezing models to train only the loras.")
        self.models["unet"].train = False
        self.models["text_encoder"].train = False
        self.models["lda"].train = False
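

# Note: setting `train = False` on the model configs above keeps the base unet, text encoder
# and autoencoder weights frozen; the only parameters left trainable are the LoRA linears
# whose `requires_grad` is switched on in `LoraConfig.apply_loras_to_target`.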


class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
    def __init__(
        self,
        config: LoraLatentDiffusionConfig,
        callbacks: "list[Callback[Any]] | None" = None,
    ) -> None:
        super().__init__(config=config, callbacks=callbacks)
        self.callbacks.extend((LoadLoras(), SaveLoras()))

    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
        return TriggerPhraseDataset(trainer=self)


class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
    def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
        lora_config = trainer.config.lora

        for target in lora_config.unet_targets:
            lora_config.apply_loras_to_target(module=trainer.unet, target=target)
        for target in lora_config.text_encoder_targets:
            lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
        for target in lora_config.lda_targets:
            lora_config.apply_loras_to_target(module=trainer.lda, target=target)
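

# The adapters are injected when training starts (`on_train_begin`), after the base models
# have been frozen by `model_post_init`, so only the LoRA linears end up with
# `requires_grad=True`.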


class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
    def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
        lora_config = trainer.config.lora

        def get_weight(linear: fl.Linear) -> Tensor:
            assert linear.bias is None
            return linear.state_dict()["weight"]

        def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
            weights: list[Tensor] = []
            for lora in module.layers(layer_type=Lora):
                linears = list(lora.layers(fl.Linear))
                assert len(linears) == 2
                # See `load_lora_weights` in refiners.adapters.lora
                weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
            return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}

        tensors: dict[str, Tensor] = {}
        metadata: dict[str, str] = {}

        if lora_config.unet_targets:
            tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
            metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}

        if lora_config.text_encoder_targets:
            tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
            metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}

        if lora_config.lda_targets:
            tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
            metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}

        save_to_safetensors(
            path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
            tensors=tensors,
            metadata=metadata,
        )
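

# The checkpoint written above stores one tensor per LoRA linear, keyed "unet.000",
# "unet.001", ... / "text_encoder.000", ... / "lda.000", ..., with each LoRA contributing
# an (up_weight, down_weight) pair in that order, plus metadata recording which targets
# each prefix was trained on.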


if __name__ == "__main__":
    import sys

    config_path = sys.argv[1]
    config = LoraLatentDiffusionConfig.load_from_toml(
        toml_path=config_path,
    )
    trainer = LoraLatentDiffusionTrainer(config=config)
    trainer.train()
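
# Usage sketch (the file name is illustrative):
#
#   python finetune_ldm_lora.py path/to/config.toml
#
# The single positional argument is the path to the TOML configuration file.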