mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
147 lines
5.2 KiB
Python
147 lines
5.2 KiB
Python
import argparse
|
|
from pathlib import Path
|
|
from typing import cast
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.nn.init import zeros_
|
|
from torch.nn import Parameter as TorchParameter
|
|
|
|
from diffusers import DiffusionPipeline # type: ignore
|
|
|
|
import refiners.fluxion.layers as fl
|
|
from refiners.fluxion.model_converter import ModelConverter
|
|
from refiners.fluxion.utils import save_to_safetensors
|
|
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
|
from refiners.foundationals.latent_diffusion import SD1UNet
|
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
|
|
|
|
|
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
|
assert linear.bias is None
|
|
return linear.state_dict()["weight"]
|
|
|
|
|
|
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
|
|
weights: list[torch.Tensor] = []
|
|
for lora in module.layers(layer_type=Lora):
|
|
linears = list(lora.layers(layer_type=fl.Linear))
|
|
assert len(linears) == 2
|
|
weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0]))) # aka (up_weight, down_weight)
|
|
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
|
|
|
|
|
|
class Args(argparse.Namespace):
|
|
source_path: str
|
|
base_model: str
|
|
output_file: str
|
|
verbose: bool
|
|
|
|
|
|
@torch.no_grad()
|
|
def process(args: Args) -> None:
|
|
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
|
|
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore
|
|
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
|
|
|
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
|
target = LoraTarget.CrossAttention
|
|
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
|
rank = diffusers_state_dict[
|
|
"mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
|
|
].shape[0]
|
|
|
|
x = torch.randn(1, 4, 32, 32)
|
|
timestep = torch.tensor(data=[0])
|
|
clip_text_embeddings = torch.randn(1, 77, 768)
|
|
|
|
refiners_model.set_timestep(timestep=timestep)
|
|
refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
|
|
refiners_args = (x,)
|
|
|
|
diffusers_args = (x, timestep, clip_text_embeddings)
|
|
|
|
converter = ModelConverter(
|
|
source_model=refiners_model, target_model=diffusers_model, skip_output_check=True, verbose=args.verbose
|
|
)
|
|
if not converter.run(
|
|
source_args=refiners_args,
|
|
target_args=diffusers_args,
|
|
):
|
|
raise RuntimeError("Model conversion failed")
|
|
|
|
diffusers_to_refiners = converter.get_mapping()
|
|
|
|
LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
|
|
|
|
for layer in refiners_model.layers(layer_type=Lora):
|
|
zeros_(tensor=layer.Linear_1.weight)
|
|
|
|
targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
|
|
for target_k in targets:
|
|
k_p, k_s = target_k.split(".processor.")
|
|
r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
|
|
assert len(r) == 1
|
|
orig_k = r[0]
|
|
orig_path = orig_k.split(sep=".")
|
|
p = refiners_model
|
|
for seg in orig_path[:-1]:
|
|
p = p[seg]
|
|
assert isinstance(p, fl.Chain)
|
|
last_seg = (
|
|
"SingleLoraAdapter"
|
|
if orig_path[-1] == "Linear"
|
|
else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
|
|
)
|
|
p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
|
|
p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
|
|
p[last_seg].Lora.load_weights(p_down, p_up)
|
|
|
|
state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
|
|
assert len(state_dict) == 320
|
|
save_to_safetensors(path=args.output_path, tensors=state_dict, metadata=metadata)
|
|
|
|
|
|
def main() -> None:
|
|
parser = argparse.ArgumentParser(description="Convert LoRAs saved using the diffusers library to refiners format.")
|
|
parser.add_argument(
|
|
"--from",
|
|
type=str,
|
|
dest="source_path",
|
|
required=True,
|
|
help="Source file path (.bin|safetensors) containing the LoRAs.",
|
|
)
|
|
parser.add_argument(
|
|
"--base-model",
|
|
type=str,
|
|
required=False,
|
|
default="runwayml/stable-diffusion-v1-5",
|
|
help="Base model, used for the UNet structure. Default: runwayml/stable-diffusion-v1-5",
|
|
)
|
|
parser.add_argument(
|
|
"--to",
|
|
type=str,
|
|
dest="output_path",
|
|
required=False,
|
|
default=None,
|
|
help=(
|
|
"Output file path (.safetensors) for converted LoRAs. If not provided, the output path will be the same as"
|
|
" the source path."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--verbose",
|
|
action="store_true",
|
|
dest="verbose",
|
|
default=False,
|
|
help="Use this flag to print verbose output during conversion.",
|
|
)
|
|
args = parser.parse_args(namespace=Args())
|
|
if args.output_path is None:
|
|
args.output_path = f"{Path(args.source_path).stem}-refiners.safetensors"
|
|
process(args=args)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|