diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py
index 19c07b8..40ad8a5 100644
--- a/scripts/training/finetune-ldm-lora.py
+++ b/scripts/training/finetune-ldm-lora.py
@@ -3,7 +3,7 @@ from typing import Any
 from pydantic import BaseModel
 from loguru import logger
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -79,9 +79,10 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
 
         for model_name in MODELS:
             model = getattr(trainer, model_name)
+            model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
             adapter = LoraAdapter[type(model)](
                 model,
-                sub_targets=getattr(lora_config, f"{model_name}_targets"),
+                sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
                 rank=lora_config.rank,
             )
             for sub_adapter, _ in adapter.sub_adapters:
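
The key change is that the per-model LoraTarget entries from the config are no longer passed to LoraAdapter directly; each target is first expanded through lora_targets(model, target) and the results are flattened into one list of concrete sub-targets via a nested comprehension. A minimal sketch of that flattening pattern is below; it assumes only what the diff shows (lora_targets takes a model and one LoraTarget and yields matching sub-targets), and the toy_lora_targets helper and dict-based "model" are hypothetical stand-ins for illustration, not the refiners API.

    # Hedged sketch of the nested-comprehension flattening used in the patch.
    from typing import Iterator


    def toy_lora_targets(model: dict[str, list[str]], target: str) -> Iterator[str]:
        # Hypothetical stand-in for lora_targets: yield every sub-module
        # registered under one target name.
        yield from model.get(target, [])


    model = {
        "CrossAttention": ["attn1.to_q", "attn1.to_v"],
        "SelfAttention": ["attn2.to_q"],
    }
    model_targets = ["CrossAttention", "SelfAttention"]

    # Same shape as the patched call site: one flat list of sub-targets
    # covering every configured target for this model.
    sub_targets = [x for target in model_targets for x in toy_lora_targets(model, target)]
    assert sub_targets == ["attn1.to_q", "attn1.to_v", "attn2.to_q"]

With the targets resolved up front, the adapter receives a flat list of sub-targets for each model rather than raw LoraTarget values, which is what the new sub_targets argument in the diff expresses.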