# refiners/scripts/training/finetune-ldm-lora.py
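"""Fine-tune LoRA adapters for a latent diffusion model with refiners.

The base UNet, text encoder and latent autoencoder (LDA) stay frozen;
low-rank adapters are injected into the configured target layers and only
their weights are trained. The run is driven by a TOML config file:

    python finetune-ldm-lora.py path/to/config.toml
"""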

import random
from typing import Any

from loguru import logger
from pydantic import BaseModel
from torch import Tensor
from torch.utils.data import Dataset

import refiners.fluxion.layers as fl
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
    FinetuneLatentDiffusionConfig,
    LatentDiffusionConfig,
    LatentDiffusionTrainer,
    TextEmbeddingLatentsBatch,
    TextEmbeddingLatentsDataset,
)


class LoraConfig(BaseModel):
    rank: int = 32  # rank of the low-rank decomposition in each adapter
    trigger_phrase: str = ""  # phrase prepended to every caption (empty to disable)
    use_only_trigger_probability: float = 0.0  # probability of keeping only the trigger phrase
    unet_targets: list[LoraTarget]
    text_encoder_targets: list[LoraTarget]
    lda_targets: list[LoraTarget]
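
# A minimal sketch of the matching [lora] table in the training TOML; the
# trigger phrase and target names below are illustrative (valid targets are
# the members of LoraTarget):
#
#   [lora]
#   rank = 32
#   trigger_phrase = "sks toy"
#   use_only_trigger_probability = 0.1
#   unet_targets = ["CrossAttention"]
#   text_encoder_targets = []
#   lda_targets = []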


class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
    def __init__(
        self,
        trainer: "LoraLatentDiffusionTrainer",
    ) -> None:
        super().__init__(trainer=trainer)
        self.trigger_phrase = trainer.config.lora.trigger_phrase
        self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
        logger.info(f"Trigger phrase: {self.trigger_phrase}")
    def process_caption(self, caption: str) -> str:
        caption = super().process_caption(caption=caption)
        if self.trigger_phrase:
            # With probability `use_only_trigger_probability`, keep only the
            # trigger phrase; otherwise prepend it to the caption.
            caption = (
                self.trigger_phrase
                if random.random() < self.use_only_trigger_probability
                else f"{self.trigger_phrase} {caption}"
            )
        return caption
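
    # For example, with trigger_phrase = "sks toy" (hypothetical) and the
    # caption "a photo of a dog", process_caption returns
    # "sks toy a photo of a dog", or just "sks toy" with probability
    # use_only_trigger_probability.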


class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
    dataset: HuggingfaceDatasetConfig
    latent_diffusion: LatentDiffusionConfig
    lora: LoraConfig

    def model_post_init(self, __context: Any) -> None:
        """Pydantic v2 does post-init differently, so we need to override this method too."""
        # Freeze the base models: only the injected LoRA weights are trained.
        logger.info("Freezing models to train only the loras.")
        self.models["unet"].train = False
        self.models["text_encoder"].train = False
        self.models["lda"].train = False


class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
    def __init__(
        self,
        config: LoraLatentDiffusionConfig,
        callbacks: "list[Callback[Any]] | None" = None,
    ) -> None:
        super().__init__(config=config, callbacks=callbacks)
        self.callbacks.extend((LoadLoras(), SaveLoras()))

    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
        return TriggerPhraseDataset(trainer=self)


class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
    def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
        lora_config = trainer.config.lora
        for model_name in MODELS:
            model = getattr(trainer, model_name)
            model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
            adapter = LoraAdapter[type(model)](
                model,
                sub_targets=lora_targets(model, model_targets),
                rank=lora_config.rank,
            )
            # Only the LoRA linear layers receive gradients; everything else
            # stays frozen (see LoraLatentDiffusionConfig.model_post_init).
            for sub_adapter, _ in adapter.sub_adapters:
                for linear in sub_adapter.Lora.layers(fl.Linear):
                    linear.requires_grad_(requires_grad=True)
            adapter.inject()


class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
    def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
        tensors: dict[str, Tensor] = {}
        metadata: dict[str, str] = {}
        for model_name in MODELS:
            model = getattr(trainer, model_name)
            # After injection (see LoadLoras), each model's parent is its LoraAdapter.
            adapter = model.parent
            # Flatten all LoRA weights into keys like "unet.000", "unet.001", ...
            tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
            # Record which layers were targeted so the file is self-describing.
            metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
        save_to_safetensors(
            path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
            tensors=tensors,
            metadata=metadata,
        )
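

# A minimal sketch of how a saved checkpoint can be inspected, using the
# standard safetensors reading API (the file name below is hypothetical):
#
#   from safetensors import safe_open
#
#   with safe_open("step1000.safetensors", framework="pt") as f:
#       print(f.metadata())       # {"unet_targets": "...", ...}
#       for key in f.keys():      # "unet.000", "unet.001", ...
#           tensor = f.get_tensor(key)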


if __name__ == "__main__":
    import sys

    # Usage: python finetune-ldm-lora.py path/to/config.toml
    config_path = sys.argv[1]
    config = LoraLatentDiffusionConfig.load_from_toml(
        toml_path=config_path,
    )
    trainer = LoraLatentDiffusionTrainer(config=config)
    trainer.train()