# refiners/scripts/training/finetune-ldm-lora.py

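"""Fine-tune a latent diffusion model by training LoRA adapters.

The base UNet, text encoder and latent autoencoder (LDA) are frozen; LoRA
layers are injected into the configured targets and their weights are saved
to a safetensors file at every checkpoint.

Usage: python finetune-ldm-lora.py path/to/config.toml
"""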
import random
from typing import Any
from pydantic import BaseModel
from loguru import logger
from refiners.adapters.lora import LoraAdapter, Lora
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
import refiners.fluxion.layers as fl
from torch import Tensor
from torch.utils.data import Dataset
from refiners.training_utils.callback import Callback
from refiners.training_utils.latent_diffusion import (
    FinetuneLatentDiffusionConfig,
    TextEmbeddingLatentsBatch,
    TextEmbeddingLatentsDataset,
    LatentDiffusionTrainer,
    LatentDiffusionConfig,
)

class LoraConfig(BaseModel):
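    """Rank, trigger-phrase and injection-target settings for the LoRA adapters."""
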
    rank: int = 32
    trigger_phrase: str = ""
    # Probability of dropping the original caption and training on the trigger phrase alone.
    use_only_trigger_probability: float = 0.0
    unet_targets: list[LoraTarget]
    text_encoder_targets: list[LoraTarget]
    lda_targets: list[LoraTarget]

    def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
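        # Wrap each matching Linear with a LoraAdapter and make only the LoRA
        # weights trainable; the wrapped base weights stay frozen.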
        for linear, parent in lora_targets(module, target):
            adapter = LoraAdapter(target=linear, rank=self.rank)
            adapter.inject(parent)
            for linear in adapter.Lora.layers(fl.Linear):
                linear.requires_grad_(requires_grad=True)

class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
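    """Dataset that prepends (or substitutes) a trigger phrase in the training captions."""
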
    def __init__(
        self,
        trainer: "LoraLatentDiffusionTrainer",
    ) -> None:
        super().__init__(trainer=trainer)
        self.trigger_phrase = trainer.config.lora.trigger_phrase
        self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
        logger.info(f"Trigger phrase: {self.trigger_phrase}")

    def process_caption(self, caption: str) -> str:
        caption = super().process_caption(caption=caption)
        if self.trigger_phrase:
            # With probability `use_only_trigger_probability`, train on the trigger
            # phrase alone; otherwise prepend it to the original caption. (The
            # original branches were swapped, which made the default of 0.0 drop
            # every caption.)
            caption = (
                self.trigger_phrase
                if random.random() < self.use_only_trigger_probability
                else f"{self.trigger_phrase} {caption}"
            )
        return caption

class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
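    """Training config with a dedicated `lora` section; the base models are frozen."""
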
    latent_diffusion: LatentDiffusionConfig
    lora: LoraConfig

    def model_post_init(self, __context: Any) -> None:
        """Pydantic v2 handles post-init differently, so we need to override this method too."""
        logger.info("Freezing models to train only the LoRAs.")
        self.models["unet"].train = False
        self.models["text_encoder"].train = False
        self.models["lda"].train = False

class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
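    """Latent diffusion trainer wired up with the LoRA load/save callbacks and the trigger-phrase dataset."""
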
    def __init__(
        self,
        config: LoraLatentDiffusionConfig,
        callbacks: "list[Callback[Any]] | None" = None,
    ) -> None:
        super().__init__(config=config, callbacks=callbacks)
        self.callbacks.extend((LoadLoras(), SaveLoras()))

    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
        return TriggerPhraseDataset(trainer=self)

class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
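    """Injects the configured LoRA adapters into the UNet, text encoder and LDA before training starts."""
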
    def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
        lora_config = trainer.config.lora
        for target in lora_config.unet_targets:
            lora_config.apply_loras_to_target(module=trainer.unet, target=target)
        for target in lora_config.text_encoder_targets:
            lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
        for target in lora_config.lda_targets:
            lora_config.apply_loras_to_target(module=trainer.lda, target=target)

class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
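    """Serializes the LoRA weights (one up/down pair per adapter) to a safetensors checkpoint."""
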
    def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
        lora_config = trainer.config.lora

        def get_weight(linear: fl.Linear) -> Tensor:
            assert linear.bias is None
            return linear.state_dict()["weight"]

        def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
            weights: list[Tensor] = []
            for lora in module.layers(layer_type=Lora):
                linears = list(lora.layers(fl.Linear))
                assert len(linears) == 2
                # See `load_lora_weights` in refiners.adapters.lora
                weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
            return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}

        tensors: dict[str, Tensor] = {}
        metadata: dict[str, str] = {}

        if lora_config.unet_targets:
            tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
            metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
        if lora_config.text_encoder_targets:
            tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
            metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
        if lora_config.lda_targets:
            tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
            metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}

        save_to_safetensors(
            path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
            tensors=tensors,
            metadata=metadata,
        )

if __name__ == "__main__":
    import sys

    config_path = sys.argv[1]
    config = LoraLatentDiffusionConfig.load_from_toml(
        toml_path=config_path,
    )
    trainer = LoraLatentDiffusionTrainer(config=config)
    trainer.train()
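
# A minimal `[lora]` section of the TOML config might look like the sketch
# below. The `latent_diffusion`, `models` and other training sections come
# from FinetuneLatentDiffusionConfig and are omitted here; the trigger phrase
# and LoraTarget names are illustrative assumptions, not a definitive list.
#
#   [lora]
#   rank = 32
#   trigger_phrase = "a photo of sks"             # assumed example phrase
#   use_only_trigger_probability = 0.5
#   unet_targets = ["CrossAttentionBlock2d"]      # assumed target name
#   text_encoder_targets = ["TransformerLayer"]   # assumed target name
#   lda_targets = []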