refiners/scripts/training/finetune-ldm-lora.py

import random
from typing import Any
from pydantic import BaseModel
from loguru import logger
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
import refiners.fluxion.layers as fl
from torch import Tensor
from torch.utils.data import Dataset
from refiners.training_utils.callback import Callback
from refiners.training_utils.latent_diffusion import (
    FinetuneLatentDiffusionConfig,
    TextEmbeddingLatentsBatch,
    TextEmbeddingLatentsDataset,
    LatentDiffusionTrainer,
    LatentDiffusionConfig,
)


class LoraConfig(BaseModel):
    rank: int = 32
    trigger_phrase: str = ""
    use_only_trigger_probability: float = 0.0
    unet_targets: list[LoraTarget]
    text_encoder_targets: list[LoraTarget]
    lda_targets: list[LoraTarget]
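
# A minimal sketch of the corresponding [lora] table in the training TOML. Field
# names mirror LoraConfig above; the target strings are hypothetical and must be
# valid LoraTarget values:
#
#   [lora]
#   rank = 32
#   trigger_phrase = "a photo of sks dog"
#   use_only_trigger_probability = 0.1
#   unet_targets = ["CrossAttentionBlock2d"]
#   text_encoder_targets = ["TransformerLayer"]
#   lda_targets = []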


class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
    def __init__(
        self,
        trainer: "LoraLatentDiffusionTrainer",
    ) -> None:
        super().__init__(trainer=trainer)
        self.trigger_phrase = trainer.config.lora.trigger_phrase
        self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
        logger.info(f"Trigger phrase: {self.trigger_phrase}")

    def process_caption(self, caption: str) -> str:
        caption = super().process_caption(caption=caption)
        if self.trigger_phrase:
            # With probability `use_only_trigger_probability`, keep only the trigger
            # phrase; otherwise prepend it to the original caption.
            caption = (
                self.trigger_phrase
                if random.random() < self.use_only_trigger_probability
                else f"{self.trigger_phrase} {caption}"
            )
        return caption
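    # For example, with use_only_trigger_probability = 0.1, roughly 10% of training
    # captions are replaced by the trigger phrase alone and the remaining ~90% are
    # prefixed with it.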


class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
    latent_diffusion: LatentDiffusionConfig
    lora: LoraConfig

    def model_post_init(self, __context: Any) -> None:
        """Pydantic v2 does post init differently, so we need to override this method too."""
        logger.info("Freezing models to train only the loras.")
        self.models["unet"].train = False
        self.models["text_encoder"].train = False
        self.models["lda"].train = False


class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
    def __init__(
        self,
        config: LoraLatentDiffusionConfig,
        callbacks: "list[Callback[Any]] | None" = None,
    ) -> None:
        super().__init__(config=config, callbacks=callbacks)
        self.callbacks.extend((LoadLoras(), SaveLoras()))

    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
        return TriggerPhraseDataset(trainer=self)


class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
    def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
        lora_config = trainer.config.lora
        for model_name in MODELS:
            model = getattr(trainer, model_name)
            # Wrap each frozen model with a LoRA adapter targeting the configured sub-modules.
            adapter = LoraAdapter[type(model)](
                model,
                sub_targets=getattr(lora_config, f"{model_name}_targets"),
                rank=lora_config.rank,
            )
            # Only the LoRA linear layers receive gradients.
            for sub_adapter, _ in adapter.sub_adapters:
                for linear in sub_adapter.Lora.layers(fl.Linear):
                    linear.requires_grad_(requires_grad=True)
            adapter.inject()
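
# Sanity check (sketch): once the adapters are injected, only the LoRA linear
# layers should require gradients, e.g.
#
#   trainable = sum(p.numel() for p in trainer.unet.parameters() if p.requires_grad)
#   logger.info(f"trainable UNet parameters: {trainable}")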


class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
    def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
        tensors: dict[str, Tensor] = {}
        metadata: dict[str, str] = {}

        for model_name in MODELS:
            model = getattr(trainer, model_name)
            adapter = model.parent
            tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
            metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}

        save_to_safetensors(
            path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
            tensors=tensors,
            metadata=metadata,
        )
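
# Checkpoint layout written by SaveLoras: tensors are keyed "<model_name>.<index>"
# (e.g. "unet.000") and the metadata maps "<model_name>_targets" to a
# comma-separated list of targets. A minimal inspection sketch with the
# `safetensors` package (the file name is hypothetical):
#
#   from safetensors import safe_open
#
#   with safe_open("step1000.safetensors", framework="pt") as f:
#       print(f.metadata())
#       print(list(f.keys())[:5])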


if __name__ == "__main__":
    import sys

    config_path = sys.argv[1]
    config = LoraLatentDiffusionConfig.load_from_toml(
        toml_path=config_path,
    )
    trainer = LoraLatentDiffusionTrainer(config=config)
    trainer.train()
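
# Usage (sketch): pass the path to a training TOML as the first argument, e.g.
#
#   python scripts/training/finetune-ldm-lora.py path/to/config.toml
#
# The TOML must populate FinetuneLatentDiffusionConfig along with the
# [latent_diffusion] and [lora] sections defined above.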