mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
fix LoRA training script
This commit is contained in:
parent
9f6733de8e
commit
9cf622a6e2
|
@ -3,7 +3,7 @@ from typing import Any
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from refiners.fluxion.utils import save_to_safetensors
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
@ -79,9 +79,10 @@ class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||||
|
|
||||||
for model_name in MODELS:
|
for model_name in MODELS:
|
||||||
model = getattr(trainer, model_name)
|
model = getattr(trainer, model_name)
|
||||||
|
model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
|
||||||
adapter = LoraAdapter[type(model)](
|
adapter = LoraAdapter[type(model)](
|
||||||
model,
|
model,
|
||||||
sub_targets=getattr(lora_config, f"{model_name}_targets"),
|
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
||||||
rank=lora_config.rank,
|
rank=lora_config.rank,
|
||||||
)
|
)
|
||||||
for sub_adapter, _ in adapter.sub_adapters:
|
for sub_adapter, _ in adapter.sub_adapters:
|
||||||
|
|
Loading…
Reference in a new issue