mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 22:58:45 +00:00

fix typing for scripts

parent 89224c1e75, commit c9fba44f39
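
The same typing pattern repeats across the nine conversion scripts below: diffusers/transformers imports get `# type: ignore` (those packages ship without type stubs), calls switch from positional to keyword arguments, the `dict[str, str] | None` returned by create_state_dict_mapping is narrowed with an assert before use, and most scripts write weights through refiners' own save_to_safetensors instead of safetensors.torch.save_file. A condensed sketch of the pattern, assuming the refiners signatures exactly as they appear in the hunks below (convert_and_save itself is a hypothetical helper, not part of the commit):

import torch
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors


def convert_and_save(src_model: torch.nn.Module, dst_model: torch.nn.Module, x: torch.Tensor, output_file: str) -> None:
    # Keyword arguments everywhere; create_state_dict_mapping returns
    # dict[str, str] | None, so narrow the None away before using it.
    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
    assert mapping is not None, "Model conversion failed"
    state_dict = convert_state_dict(
        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
    )
    # Weights are halved and written with refiners' wrapper rather than
    # safetensors.torch.save_file.
    save_to_safetensors(path=output_file, tensors={k: v.half() for k, v in state_dict.items()})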
Script 1: CLIP text encoder (SD 1.5) weight conversion

@@ -1,13 +1,9 @@
 import torch
 
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
 
-from diffusers import DiffusionPipeline
-from transformers.models.clip.modeling_clip import CLIPTextModel
+from diffusers import DiffusionPipeline  # type: ignore
+from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 
@@ -16,12 +12,15 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderL()
     x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(src_model, dst_model, [x])
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -41,9 +40,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).text_encoder
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
Script 2: ControlNet weight conversion

@@ -1,10 +1,10 @@
 import torch
-from diffusers import ControlNetModel
-from safetensors.torch import save_file
+from diffusers import ControlNetModel  # type: ignore
 from refiners.fluxion.utils import (
     forward_order_of_execution,
     verify_shape_match,
     convert_state_dict,
+    save_to_safetensors,
 )
 from refiners.foundationals.latent_diffusion.controlnet import Controlnet
 from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
@@ -16,16 +16,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     controlnet = Controlnet(name="mycn")
 
     condition = torch.randn(1, 3, 512, 512)
-    controlnet.set_controlnet_condition(condition)
+    controlnet.set_controlnet_condition(condition=condition)
 
-    unet = UNet(4, clip_embedding_dim=768)
-    unet.insert(0, controlnet)
+    unet = UNet(in_channels=4, clip_embedding_dim=768)
+    unet.insert(index=0, module=controlnet)
     clip_text_embedding = torch.rand(1, 77, 768)
-    unet.set_clip_text_embedding(clip_text_embedding)
+    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
 
     scheduler = DPMSolver(num_inference_steps=10)
-    timestep = scheduler.timesteps[0].unsqueeze(0)
-    unet.set_timestep(timestep.unsqueeze(0))
+    timestep = scheduler.timesteps[0].unsqueeze(dim=0)
+    unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
 
     x = torch.randn(1, 4, 64, 64)
 
@@ -33,8 +33,8 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     # to diffusers in order, since we compute the residuals inline instead of
     # in a separate step.
 
-    source_order = forward_order_of_execution(controlnet_src, (x, timestep, clip_text_embedding, condition))
-    target_order = forward_order_of_execution(controlnet, (x,))
+    source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition))  # type: ignore
+    target_order = forward_order_of_execution(module=controlnet, example_args=(x,))
 
     broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
 
@@ -162,7 +162,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     assert target_order[broken_k] == expected_target_order
     source_order[broken_k] = fixed_source_order
 
-    assert verify_shape_match(source_order, target_order)
+    assert verify_shape_match(source_order=source_order, target_order=target_order)
 
     mapping: dict[str, str] = {}
     for model_type_shape in source_order:
@@ -170,12 +170,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
         target_keys = target_order[model_type_shape]
         mapping.update(zip(target_keys, source_keys))
 
-    state_dict = convert_state_dict(controlnet_src.state_dict(), controlnet.state_dict(), state_dict_mapping=mapping)
+    state_dict = convert_state_dict(
+        source_state_dict=controlnet_src.state_dict(),
+        target_state_dict=controlnet.state_dict(),
+        state_dict_mapping=mapping,
+    )
 
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -194,9 +198,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    controlnet_src = ControlNetModel.from_pretrained(args.source)
-    tensors = convert(controlnet_src)
-    save_file(tensors, args.output_file)
+    controlnet_src = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source)  # type: ignore
+    tensors = convert(controlnet_src=controlnet_src)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
Script 3: diffusers LoRA weight conversion

@@ -1,6 +1,7 @@
 # Note: this conversion script currently only support simple LoRAs which adapt
 # the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
 
+from typing import cast
 import torch
 from torch.nn.init import zeros_
 from torch.nn import Parameter as TorchParameter
@@ -13,7 +14,7 @@ from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
 from refiners.adapters.lora import Lora
 from refiners.fluxion.utils import create_state_dict_mapping
 
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline  # type: ignore
 
 
 def get_weight(linear: fl.Linear) -> torch.Tensor:
@@ -24,17 +25,17 @@ def get_weight(linear: fl.Linear) -> torch.Tensor:
 def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
     weights: list[torch.Tensor] = []
     for lora in module.layers(layer_type=Lora):
-        linears = list(lora.layers(fl.Linear))
+        linears = list(lora.layers(layer_type=fl.Linear))
         assert len(linears) == 2
-        weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
-    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
+        weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0])))  # aka (up_weight, down_weight)
+    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
 
 
 @torch.no_grad()
 def process(source: str, base_model: str, output_file: str) -> None:
     diffusers_state_dict = torch.load(source, map_location="cpu")  # type: ignore
-    diffusers_sd = DiffusionPipeline.from_pretrained(base_model)  # type: ignore
-    diffusers_model = diffusers_sd.unet
+    diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model)  # type: ignore
+    diffusers_model = cast(fl.Module, diffusers_sd.unet)  # type: ignore
 
     refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
     target = LoraTarget.CrossAttention
@@ -44,21 +45,23 @@ def process(source: str, base_model: str, output_file: str) -> None:
     ].shape[0]
 
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
-    refiners_model.set_timestep(timestep)
-    refiners_model.set_clip_text_embedding(clip_text_embeddings)
+    refiners_model.set_timestep(timestep=timestep)
+    refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     refiners_args = (x,)
 
     diffusers_args = (x, timestep, clip_text_embeddings)
 
-    diffusers_to_refiners = create_state_dict_mapping(refiners_model, diffusers_model, refiners_args, diffusers_args)
-    assert diffusers_to_refiners
+    diffusers_to_refiners = create_state_dict_mapping(
+        source_model=refiners_model, target_model=diffusers_model, source_args=refiners_args, target_args=diffusers_args
+    )
+    assert diffusers_to_refiners is not None
 
-    apply_loras_to_target(refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
+    apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
     for layer in refiners_model.layers(layer_type=Lora):
-        zeros_(layer.Linear_1.weight)
+        zeros_(tensor=layer.Linear_1.weight)
 
     targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
     for target_k in targets:
@@ -66,23 +69,24 @@ def process(source: str, base_model: str, output_file: str) -> None:
         r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
         assert len(r) == 1
         orig_k = r[0]
-        orig_path = orig_k.split(".")
+        orig_path = orig_k.split(sep=".")
         p = refiners_model
         for seg in orig_path[:-1]:
             p = p[seg]
+        assert isinstance(p, fl.Chain)
         last_seg = (
             "LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
         )
-        p_down = TorchParameter(diffusers_state_dict[f"{target_k}_lora.down.weight"])
-        p_up = TorchParameter(diffusers_state_dict[f"{target_k}_lora.up.weight"])
+        p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
+        p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
         p[last_seg].Lora.load_weights(p_down, p_up)
 
-    state_dict = build_loras_safetensors(refiners_model, key_prefix="unet.")
+    state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
     assert len(state_dict) == 320
-    save_to_safetensors(output_file, tensors=state_dict, metadata=metadata)
+    save_to_safetensors(path=output_file, tensors=state_dict, metadata=metadata)
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
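Two different narrowing tools appear in the LoRA script above: typing.cast for diffusers_sd.unet (checker-only, nothing verified at runtime) and assert isinstance(p, fl.Chain) before the p[seg] lookups (a real runtime check). A minimal standalone illustration of the difference, using a hypothetical payload value rather than anything from the commit:

from typing import cast


def narrow_with_cast(payload: object) -> str:
    # cast() only informs the type checker; nothing is verified at runtime,
    # so a wrong cast surfaces later. The script uses it for
    # diffusers_sd.unet, which pyright would otherwise treat as untyped.
    return cast(str, payload)


def narrow_with_isinstance(payload: object) -> str:
    # assert isinstance() narrows the type AND fails fast at runtime, the
    # way the script checks isinstance(p, fl.Chain) before indexing p[seg].
    assert isinstance(payload, str)
    return payload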
Script 4: LoRA conversion to SD-WebUI format

@@ -1,4 +1,8 @@
-from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors
+from refiners.fluxion.utils import (
+    load_from_safetensors,
+    load_metadata_from_safetensors,
+    save_to_safetensors,
+)
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.unet import UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget
@@ -8,33 +12,33 @@ from refiners.fluxion.utils import create_state_dict_mapping
 
 import torch
 
-from diffusers import DiffusionPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from transformers.models.clip.modeling_clip import CLIPTextModel
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
+from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 
 
 @torch.no_grad()
 def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)
 
-    return create_state_dict_mapping(src_model, dst_model, src_args, dst_args)  # type: ignore
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
 
 
 @torch.no_grad()
 def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
     x = dst_model.tokenizer("Nice cat", sequence_length=77)
 
-    return create_state_dict_mapping(src_model, dst_model, [x])  # type: ignore
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -61,11 +65,11 @@ def main():
     )
     args = parser.parse_args()
 
-    metadata = load_metadata_from_safetensors(args.input_file)
+    metadata = load_metadata_from_safetensors(path=args.input_file)
     assert metadata is not None
-    tensors = load_from_safetensors(args.input_file)
+    tensors = load_from_safetensors(path=args.input_file)
 
-    diffusers_sd = DiffusionPipeline.from_pretrained(args.sd15)  # type: ignore
+    diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.sd15)  # type: ignore
 
     state_dict: dict[str, torch.Tensor] = {}
 
@@ -110,16 +114,16 @@ def main():
         # Compute the corresponding diffusers' keys where LoRA layers must be applied
        lora_injection_points: list[str] = [
             refiners_to_diffusers[submodule_to_key[linear]]
-            for target in [LoraTarget(t) for t in meta_value.split(",")]
+            for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
             for layer in dst_model.layers(layer_type=target.get_class())
-            for linear in layer.layers(fl.Linear)
+            for linear in layer.layers(layer_type=fl.Linear)
         ]
 
         lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]]
         assert len(lora_injection_points) == len(lora_weights) // 2
 
         # Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
-        for i, diffusers_key in enumerate(lora_injection_points):
+        for i, diffusers_key in enumerate(iterable=lora_injection_points):
             lora_key = lora_prefix + diffusers_key.replace(".", "_")
             # Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
             # by default (see `lora_calc_updown` in SD-WebUI for more details)
@@ -127,7 +131,7 @@ def main():
             state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
 
     assert state_dict
-    save_to_safetensors(args.output_file, state_dict)
+    save_to_safetensors(path=args.output_file, tensors=state_dict)
 
 
 if __name__ == "__main__":
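The SD-WebUI naming step in the script above reduces to a dot-to-underscore transform plus a fixed prefix and the .lora_up.weight / .lora_down.weight suffixes (the up-weight assignment sits just outside the hunk shown, so it is inferred from the (up, down) pair order built earlier). A sketch with a hypothetical diffusers key and prefix, neither taken from the commit:

# Hypothetical diffusers key; real keys come from refiners_to_diffusers.
diffusers_key = "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q"
lora_prefix = "lora_unet_"  # illustrative; the script derives its own prefix

lora_key = lora_prefix + diffusers_key.replace(".", "_")
# -> "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn2_to_q"

# Each injection point then gets a pair of weights, mirroring the script:
#   state_dict[lora_key + ".lora_up.weight"]   = lora_weights[2 * i]
#   state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]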
Script 5: autoencoder (LDA) weight conversion

@@ -1,13 +1,13 @@
 import torch
 
-from safetensors.torch import save_file
 from refiners.fluxion.utils import (
     create_state_dict_mapping,
     convert_state_dict,
+    save_to_safetensors,
 )
 
-from diffusers import DiffusionPipeline
-from diffusers.models.autoencoder_kl import AutoencoderKL
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.autoencoder_kl import AutoencoderKL  # type: ignore
 
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 
@@ -16,12 +16,15 @@ from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
     dst_model = LatentDiffusionAutoencoder()
     x = torch.randn(1, 3, 512, 512)
-    mapping = create_state_dict_mapping(src_model, dst_model, [x])
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -41,9 +44,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).vae
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).vae  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
Script 6: inpainting UNet weight conversion

@@ -1,13 +1,9 @@
 import torch
 
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
 
-from diffusers import StableDiffusionInpaintPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers import StableDiffusionInpaintPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 
 from refiners.foundationals.latent_diffusion.unet import UNet
 
@@ -17,20 +13,23 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = UNet(in_channels=9, clip_embedding_dim=768)
 
     x = torch.randn(1, 9, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
    dst_args = (x,)
 
-    mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -50,9 +49,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = StableDiffusionInpaintPipeline.from_pretrained(args.source).unet
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
Script 7: SD 1.5 UNet weight conversion

@@ -1,13 +1,9 @@
 import torch
 
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
 
-from diffusers import DiffusionPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 
 from refiners.foundationals.latent_diffusion.unet import UNet
 
@@ -17,20 +13,23 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = UNet(in_channels=4, clip_embedding_dim=768)
 
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)
 
-    mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -50,9 +49,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).unet
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
Script 8: SDXL text encoder 2 weight conversion

@@ -1,10 +1,7 @@
 import torch
 
 from safetensors.torch import save_file  # type: ignore
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 
 from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
@@ -28,7 +25,7 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -49,7 +46,7 @@ def main():
     )
     args = parser.parse_args()
     src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2  # type: ignore
-    tensors = convert(src_model=src_model)
+    tensors = convert(src_model=src_model)  # type: ignore
     save_file(tensors=tensors, filename=args.output_file)
Script 9: SDXL UNet weight conversion

@@ -1,10 +1,7 @@
 import torch
 
 from safetensors.torch import save_file  # type: ignore
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 
 from diffusers import DiffusionPipeline  # type: ignore
 from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
@@ -17,7 +14,7 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = SDXLUNet(in_channels=4)
 
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 2048)
 
     added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
@@ -60,8 +57,8 @@ def main() -> None:
     )
     args = parser.parse_args()
     src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_file(tensors=tensors, filename=args.output_file)
 
 
 if __name__ == "__main__":