fix LoRA training script

Pierre Chapuis 2023-09-01 10:10:13 +02:00
parent 9f6733de8e
commit 9cf622a6e2


@@ -3,7 +3,7 @@ from typing import Any
 from pydantic import BaseModel
 from loguru import logger
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -79,9 +79,10 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
         for model_name in MODELS:
             model = getattr(trainer, model_name)
+            model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
             adapter = LoraAdapter[type(model)](
                 model,
-                sub_targets=getattr(lora_config, f"{model_name}_targets"),
+                sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
                 rank=lora_config.rank,
             )
             for sub_adapter, _ in adapter.sub_adapters:
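
For context, here is a minimal sketch (not part of the commit) of what the fix does. It assumes, based on the diff, that `lora_targets(model, target)` enumerates the concrete sub-targets a `LoraTarget` enum member refers to within a given model; the helper `build_lora_adapter` is hypothetical.

```python
# Hypothetical helper illustrating the fix; build_lora_adapter is not part
# of the commit, and the lora_targets behavior is assumed from the diff.
from refiners.foundationals.latent_diffusion.lora import (
    LoraAdapter,
    LoraTarget,
    lora_targets,
)


def build_lora_adapter(model, targets: list[LoraTarget], rank: int):
    # Before the fix, the LoraTarget enum members were passed straight to
    # LoraAdapter as sub_targets. The adapter expects the concrete targets
    # that lora_targets(model, target) enumerates for the model, so each
    # named target is expanded and the results flattened into one list.
    expanded = [x for target in targets for x in lora_targets(model, target)]
    return LoraAdapter[type(model)](model, sub_targets=expanded, rank=rank)
```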