mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
121 lines
4.2 KiB
Python
121 lines
4.2 KiB
Python
import random
|
|
from typing import Any
|
|
from pydantic import BaseModel
|
|
from loguru import logger
|
|
from refiners.fluxion.utils import save_to_safetensors
|
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets
|
|
import refiners.fluxion.layers as fl
|
|
from torch import Tensor
|
|
from torch.utils.data import Dataset
|
|
|
|
from refiners.training_utils.callback import Callback
|
|
from refiners.training_utils.latent_diffusion import (
|
|
FinetuneLatentDiffusionConfig,
|
|
TextEmbeddingLatentsBatch,
|
|
TextEmbeddingLatentsDataset,
|
|
LatentDiffusionTrainer,
|
|
LatentDiffusionConfig,
|
|
)
|
|
|
|
|
|
class LoraConfig(BaseModel):
    """LoRA-specific training options (the ``lora`` section of the training config)."""

    # Rank passed as ``rank=`` to every LoraAdapter created by LoadLoras.
    rank: int = 32
    # Phrase added to captions by TriggerPhraseDataset; an empty string disables it.
    trigger_phrase: str = ""
    # Per the field name, the probability that a caption is replaced by the trigger
    # phrase alone instead of being prefixed with it. Read by
    # TriggerPhraseDataset.process_caption.
    # NOTE(review): confirm process_caption applies it in this direction.
    use_only_trigger_probability: float = 0.0
    # The three lists below are looked up as ``getattr(config, f"{model_name}_targets")``
    # for each name in MODELS and passed to ``lora_targets`` to pick the adapted layers.
    unet_targets: list[LoraTarget]
    text_encoder_targets: list[LoraTarget]
    lda_targets: list[LoraTarget]
|
|
|
|
|
|
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
    """Text/latents dataset that applies the LoRA trigger phrase to captions.

    The trigger phrase and the probability of using it on its own are read from
    ``trainer.config.lora``.
    """

    def __init__(
        self,
        trainer: "LoraLatentDiffusionTrainer",
    ) -> None:
        super().__init__(trainer=trainer)
        self.trigger_phrase = trainer.config.lora.trigger_phrase
        self.use_only_trigger_probability = trainer.config.lora.use_only_trigger_probability
        logger.info(f"Trigger phrase: {self.trigger_phrase}")

    def process_caption(self, caption: str) -> str:
        """Return the caption after applying the trigger phrase.

        The caption first goes through the parent class's processing. If no
        trigger phrase is configured, it is returned as is. Otherwise, with
        probability ``use_only_trigger_probability`` the caption is replaced by
        the trigger phrase alone. In all other cases the trigger phrase is
        prepended to it.
        """
        caption = super().process_caption(caption=caption)
        if not self.trigger_phrase:
            return caption
        # A draw below the probability selects the trigger-only caption, so the
        # default of 0.0 always keeps the caption text (prefixed by the trigger).
        if random.random() < self.use_only_trigger_probability:
            return self.trigger_phrase
        return f"{self.trigger_phrase} {caption}"
|
|
|
|
|
|
class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
    """Fine-tuning configuration for LoRA training of a latent diffusion model.

    Adds the ``lora`` section to the base fine-tuning config. After validation,
    it marks the base models as not trainable so that only the LoRAs are trained.
    """

    latent_diffusion: LatentDiffusionConfig
    lora: LoraConfig

    def model_post_init(self, __context: Any) -> None:
        """Freeze the base models once pydantic has built the config.

        Pydantic v2 runs post-init logic through ``model_post_init``, so the
        freezing lives in an override of this hook.
        """
        # Run the parent hooks first. Otherwise any post-init logic they define
        # is silently skipped. Running them first also means the freezing below
        # has the last word on the ``train`` flags.
        super().model_post_init(__context)
        logger.info("Freezing models to train only the loras.")
        for model_name in ("unet", "text_encoder", "lda"):
            self.models[model_name].train = False
|
|
|
|
|
|
class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfig]):
    """Latent diffusion trainer specialised for LoRA fine-tuning.

    After the base trainer is set up, it appends the LoadLoras and SaveLoras
    callbacks. Its dataset is a TriggerPhraseDataset.
    """

    def __init__(
        self,
        config: LoraLatentDiffusionConfig,
        callbacks: "list[Callback[Any]] | None" = None,
    ) -> None:
        super().__init__(config=config, callbacks=callbacks)
        # Injection at train start and saving at each checkpoint.
        lora_callbacks = (LoadLoras(), SaveLoras())
        self.callbacks.extend(lora_callbacks)

    def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
        """Build the caption dataset that applies the configured trigger phrase."""
        dataset = TriggerPhraseDataset(trainer=self)
        return dataset
|
|
|
|
|
|
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
    """Inject a LoRA adapter into each model named in ``MODELS`` when training begins.

    For every model, the ``<name>_targets`` list of the LoRA config selects the
    layers to adapt. The adapter's LoRA linear layers are set to require
    gradients, and then the adapter is injected into the model.
    """

    def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
        config = trainer.config.lora

        for name in MODELS:
            target_model = getattr(trainer, name)
            targets: list[LoraTarget] = getattr(config, f"{name}_targets")

            adapter_cls = LoraAdapter[type(target_model)]
            lora_adapter = adapter_cls(
                target_model,
                sub_targets=lora_targets(target_model, targets),
                rank=config.rank,
            )

            # The base models are configured with train=False, so the LoRA
            # linears must be explicitly switched back to trainable.
            trainable_linears = (
                linear
                for sub_adapter, _ in lora_adapter.sub_adapters
                for linear in sub_adapter.Lora.layers(fl.Linear)
            )
            for linear in trainable_linears:
                linear.requires_grad_(requires_grad=True)

            lora_adapter.inject()
|
|
|
|
|
|
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
    """Write the LoRA weights of every model in ``MODELS`` to a safetensors checkpoint.

    Tensors are keyed ``<model>.<index:03d>``. The metadata records each
    model's adapter targets under ``<model>_targets``. The file is named
    ``step<N>.safetensors`` inside the trainer's checkpoint folder.
    """

    def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
        tensors: dict[str, Tensor] = {}
        metadata: dict[str, str] = {}

        for name in MODELS:
            # Assumes the model's parent is the LoraAdapter injected by LoadLoras.
            lora_adapter = getattr(trainer, name).parent
            for index, weight in enumerate(lora_adapter.weights):
                tensors[f"{name}.{index:03d}"] = weight
            # NOTE(review): str.join requires `sub_targets` to yield strings
            # (e.g. LoraTarget values); confirm against LoraAdapter.
            metadata[f"{name}_targets"] = ",".join(lora_adapter.sub_targets)

        checkpoint_path = trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors"
        save_to_safetensors(path=checkpoint_path, tensors=tensors, metadata=metadata)
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Expect the path to a TOML training config as the first argument. Exit
    # with a usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <config.toml>")

    config_path = sys.argv[1]
    config = LoraLatentDiffusionConfig.load_from_toml(
        toml_path=config_path,
    )
    trainer = LoraLatentDiffusionTrainer(config=config)
    trainer.train()
|