mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix LoRA training script
This commit is contained in:
parent
9f6733de8e
commit
9cf622a6e2
|
@ -3,7 +3,7 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets
|
||||
import refiners.fluxion.layers as fl
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -79,9 +79,10 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
|||
|
||||
for model_name in MODELS:
|
||||
model = getattr(trainer, model_name)
|
||||
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
||||
adapter = LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=getattr(lora_config, f"{model_name}_targets"),
|
||||
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
||||
rank=lora_config.rank,
|
||||
)
|
||||
for sub_adapter, _ in adapter.sub_adapters:
|
||||
|
|
Loading…
Reference in a new issue