fix typing for scripts

limiteinductive 2023-08-15 14:35:17 +02:00 committed by Benjamin Trom
parent 89224c1e75
commit c9fba44f39
9 changed files with 125 additions and 119 deletions
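Two typing techniques recur throughout the diffs below: `main` functions gain an explicit `-> None` return annotation so the checker verifies their bodies, and results that may be `None` are narrowed with an `assert` instead of being passed along as `Optional`. A minimal self-contained sketch of the narrowing pattern (`find_mapping` is a hypothetical stand-in for `create_state_dict_mapping`):

from typing import Optional

def find_mapping(keys: list[str]) -> Optional[dict[str, str]]:
    # hypothetical stand-in: like create_state_dict_mapping, it returns None on failure
    return {k: k for k in keys} or None

def main() -> None:  # annotated return type: the checker now inspects the body
    mapping = find_mapping(keys=["a", "b"])
    assert mapping is not None, "mapping failed"  # narrows dict[str, str] | None to dict[str, str]
    print(mapping["a"])  # accepted by the checker only after the narrowing assert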

View file

@@ -1,13 +1,9 @@
 import torch
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
-from diffusers import DiffusionPipeline
-from transformers.models.clip.modeling_clip import CLIPTextModel
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
+from diffusers import DiffusionPipeline  # type: ignore
+from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
@@ -16,12 +12,15 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderL()
     x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(src_model, dst_model, [x])
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -41,9 +40,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).text_encoder
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)


 if __name__ == "__main__":
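The switch to keyword-only call sites in this diff is what lets a strict checker (and a reader) catch argument mix-ups between the source and target state dicts. A rough sketch of the idea, assuming nothing about the real refiners implementation (`convert_state_dict_sketch` is hypothetical; the mapping direction, target key to source key, follows the controlnet script below):

def convert_state_dict_sketch(
    *, source_state_dict: dict[str, int], target_state_dict: dict[str, int], state_dict_mapping: dict[str, str]
) -> dict[str, int]:
    # hypothetical: check the mapping covers the target, then copy source values over
    assert target_state_dict.keys() == state_dict_mapping.keys()
    return {target_key: source_state_dict[source_key] for target_key, source_key in state_dict_mapping.items()}

result = convert_state_dict_sketch(
    source_state_dict={"w": 1}, target_state_dict={"weight": 0}, state_dict_mapping={"weight": "w"}
)
assert result == {"weight": 1}  # swapping the two dicts positionally is now impossible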

View file

@@ -1,10 +1,10 @@
 import torch
-from diffusers import ControlNetModel
-from safetensors.torch import save_file
+from diffusers import ControlNetModel  # type: ignore
 from refiners.fluxion.utils import (
     forward_order_of_execution,
     verify_shape_match,
     convert_state_dict,
+    save_to_safetensors,
 )
 from refiners.foundationals.latent_diffusion.controlnet import Controlnet
 from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
@@ -16,16 +16,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     controlnet = Controlnet(name="mycn")
     condition = torch.randn(1, 3, 512, 512)
-    controlnet.set_controlnet_condition(condition)
-    unet = UNet(4, clip_embedding_dim=768)
-    unet.insert(0, controlnet)
+    controlnet.set_controlnet_condition(condition=condition)
+    unet = UNet(in_channels=4, clip_embedding_dim=768)
+    unet.insert(index=0, module=controlnet)
     clip_text_embedding = torch.rand(1, 77, 768)
-    unet.set_clip_text_embedding(clip_text_embedding)
+    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
     scheduler = DPMSolver(num_inference_steps=10)
-    timestep = scheduler.timesteps[0].unsqueeze(0)
-    unet.set_timestep(timestep.unsqueeze(0))
+    timestep = scheduler.timesteps[0].unsqueeze(dim=0)
+    unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
     x = torch.randn(1, 4, 64, 64)
@@ -33,8 +33,8 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     # to diffusers in order, since we compute the residuals inline instead of
     # in a separate step.
-    source_order = forward_order_of_execution(controlnet_src, (x, timestep, clip_text_embedding, condition))
-    target_order = forward_order_of_execution(controlnet, (x,))
+    source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition))  # type: ignore
+    target_order = forward_order_of_execution(module=controlnet, example_args=(x,))
     broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
@@ -162,7 +162,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     assert target_order[broken_k] == expected_target_order
     source_order[broken_k] = fixed_source_order

-    assert verify_shape_match(source_order, target_order)
+    assert verify_shape_match(source_order=source_order, target_order=target_order)

     mapping: dict[str, str] = {}
     for model_type_shape in source_order:
@@ -170,12 +170,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
         target_keys = target_order[model_type_shape]
         mapping.update(zip(target_keys, source_keys))

-    state_dict = convert_state_dict(controlnet_src.state_dict(), controlnet.state_dict(), state_dict_mapping=mapping)
+    state_dict = convert_state_dict(
+        source_state_dict=controlnet_src.state_dict(),
+        target_state_dict=controlnet.state_dict(),
+        state_dict_mapping=mapping,
+    )
     return {k: v.half() for k, v in state_dict.items()}


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -194,9 +198,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    controlnet_src = ControlNetModel.from_pretrained(args.source)
-    tensors = convert(controlnet_src)
-    save_file(tensors, args.output_file)
+    controlnet_src = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source)  # type: ignore
+    tensors = convert(controlnet_src=controlnet_src)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)


 if __name__ == "__main__":
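The mapping strategy in this script pairs the two models' layers by forward order of execution, keyed on weight shapes, because ControlNet's residuals are computed inline here rather than in a separate step. A toy illustration of shape-keyed, order-based matching (registration order stands in for execution order; none of this is the refiners API):

import torch

def order_by_shape(model: torch.nn.Module) -> dict[torch.Size, list[str]]:
    # collect parameter names per shape, in registration order
    out: dict[torch.Size, list[str]] = {}
    for name, param in model.named_parameters():
        out.setdefault(param.shape, []).append(name)
    return out

source = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
target = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
target_by_shape = order_by_shape(target)
mapping = {
    target_name: source_name
    for shape, source_names in order_by_shape(source).items()
    for source_name, target_name in zip(source_names, target_by_shape[shape])
}
print(mapping)  # e.g. {'0.weight': '0.weight', '0.bias': '0.bias', ...}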

View file

@@ -1,6 +1,7 @@
 # Note: this conversion script currently only support simple LoRAs which adapt
 # the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
+from typing import cast

 import torch
 from torch.nn.init import zeros_
 from torch.nn import Parameter as TorchParameter
@@ -13,7 +14,7 @@ from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras
 from refiners.adapters.lora import Lora
 from refiners.fluxion.utils import create_state_dict_mapping
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline  # type: ignore


 def get_weight(linear: fl.Linear) -> torch.Tensor:
@@ -24,17 +25,17 @@ def get_weight(linear: fl.Linear) -> torch.Tensor:
 def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
     weights: list[torch.Tensor] = []
     for lora in module.layers(layer_type=Lora):
-        linears = list(lora.layers(fl.Linear))
+        linears = list(lora.layers(layer_type=fl.Linear))
         assert len(linears) == 2
-        weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
-    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
+        weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0])))  # aka (up_weight, down_weight)
+    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}


 @torch.no_grad()
 def process(source: str, base_model: str, output_file: str) -> None:
     diffusers_state_dict = torch.load(source, map_location="cpu")  # type: ignore
-    diffusers_sd = DiffusionPipeline.from_pretrained(base_model)  # type: ignore
-    diffusers_model = diffusers_sd.unet
+    diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model)  # type: ignore
+    diffusers_model = cast(fl.Module, diffusers_sd.unet)  # type: ignore
     refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
     target = LoraTarget.CrossAttention
@@ -44,21 +45,23 @@ def process(source: str, base_model: str, output_file: str) -> None:
     ].shape[0]
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)

-    refiners_model.set_timestep(timestep)
-    refiners_model.set_clip_text_embedding(clip_text_embeddings)
+    refiners_model.set_timestep(timestep=timestep)
+    refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     refiners_args = (x,)

     diffusers_args = (x, timestep, clip_text_embeddings)

-    diffusers_to_refiners = create_state_dict_mapping(refiners_model, diffusers_model, refiners_args, diffusers_args)
-    assert diffusers_to_refiners
+    diffusers_to_refiners = create_state_dict_mapping(
+        source_model=refiners_model, target_model=diffusers_model, source_args=refiners_args, target_args=diffusers_args
+    )
+    assert diffusers_to_refiners is not None

-    apply_loras_to_target(refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
+    apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
     for layer in refiners_model.layers(layer_type=Lora):
-        zeros_(layer.Linear_1.weight)
+        zeros_(tensor=layer.Linear_1.weight)

     targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
     for target_k in targets:
@@ -66,23 +69,24 @@ def process(source: str, base_model: str, output_file: str) -> None:
         r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
         assert len(r) == 1
         orig_k = r[0]
-        orig_path = orig_k.split(".")
+        orig_path = orig_k.split(sep=".")
         p = refiners_model
         for seg in orig_path[:-1]:
             p = p[seg]
+        assert isinstance(p, fl.Chain)
         last_seg = (
             "LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
         )
-        p_down = TorchParameter(diffusers_state_dict[f"{target_k}_lora.down.weight"])
-        p_up = TorchParameter(diffusers_state_dict[f"{target_k}_lora.up.weight"])
+        p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
+        p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
         p[last_seg].Lora.load_weights(p_down, p_up)

-    state_dict = build_loras_safetensors(refiners_model, key_prefix="unet.")
+    state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
     assert len(state_dict) == 320
-    save_to_safetensors(output_file, tensors=state_dict, metadata=metadata)
+    save_to_safetensors(path=output_file, tensors=state_dict, metadata=metadata)


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
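The `cast(fl.Module, diffusers_sd.unet)` call in the hunk above is a purely static annotation: it makes the checker treat an untyped pipeline attribute as a known type, with no runtime check or conversion. A small self-contained illustration (`load_pipeline` is hypothetical):

from typing import Any, cast

def load_pipeline() -> Any:  # stands in for an untyped factory like DiffusionPipeline.from_pretrained
    class Pipe:
        unet = "weights"
    return Pipe()

pipe = load_pipeline()
unet = cast(str, pipe.unet)  # no runtime effect; only the declared static type changes
print(unet.upper())  # the checker now permits str methods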

View file

@@ -1,4 +1,8 @@
-from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors
+from refiners.fluxion.utils import (
+    load_from_safetensors,
+    load_metadata_from_safetensors,
+    save_to_safetensors,
+)
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.unet import UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget
@@ -8,33 +12,33 @@ from refiners.fluxion.utils import create_state_dict_mapping
 import torch
-from diffusers import DiffusionPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from transformers.models.clip.modeling_clip import CLIPTextModel
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
+from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore


 @torch.no_grad()
 def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)

-    return create_state_dict_mapping(src_model, dst_model, src_args, dst_args)  # type: ignore
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore


 @torch.no_grad()
 def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
     x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    return create_state_dict_mapping(src_model, dst_model, [x])  # type: ignore
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -61,11 +65,11 @@ def main():
     )
     args = parser.parse_args()

-    metadata = load_metadata_from_safetensors(args.input_file)
+    metadata = load_metadata_from_safetensors(path=args.input_file)
     assert metadata is not None
-    tensors = load_from_safetensors(args.input_file)
-    diffusers_sd = DiffusionPipeline.from_pretrained(args.sd15)  # type: ignore
+    tensors = load_from_safetensors(path=args.input_file)
+    diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.sd15)  # type: ignore

     state_dict: dict[str, torch.Tensor] = {}
@@ -110,16 +114,16 @@ def main():
     # Compute the corresponding diffusers' keys where LoRA layers must be applied
     lora_injection_points: list[str] = [
         refiners_to_diffusers[submodule_to_key[linear]]
-        for target in [LoraTarget(t) for t in meta_value.split(",")]
+        for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
         for layer in dst_model.layers(layer_type=target.get_class())
-        for linear in layer.layers(fl.Linear)
+        for linear in layer.layers(layer_type=fl.Linear)
     ]
     lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]]
     assert len(lora_injection_points) == len(lora_weights) // 2

     # Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
-    for i, diffusers_key in enumerate(lora_injection_points):
+    for i, diffusers_key in enumerate(iterable=lora_injection_points):
         lora_key = lora_prefix + diffusers_key.replace(".", "_")
         # Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
         # by default (see `lora_calc_updown` in SD-WebUI for more details)
@@ -127,7 +131,7 @@ def main():
         state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
     assert state_dict
-    save_to_safetensors(args.output_file, state_dict)
+    save_to_safetensors(path=args.output_file, tensors=state_dict)


 if __name__ == "__main__":
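For reference, the SD-WebUI key convention applied in this hunk turns a dotted diffusers module path into an underscored LoRA key with up/down weight suffixes; the prefix and key values below are illustrative, not taken from the script:

lora_prefix = "lora_unet_"  # assumed for illustration
diffusers_key = "down_blocks.0.attentions.0.proj_in"
lora_key = lora_prefix + diffusers_key.replace(".", "_")
print(lora_key + ".lora_up.weight")    # lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight
print(lora_key + ".lora_down.weight")  # lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight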

View file

@@ -1,13 +1,13 @@
 import torch
-from safetensors.torch import save_file
 from refiners.fluxion.utils import (
     create_state_dict_mapping,
     convert_state_dict,
+    save_to_safetensors,
 )
-from diffusers import DiffusionPipeline
-from diffusers.models.autoencoder_kl import AutoencoderKL
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.autoencoder_kl import AutoencoderKL  # type: ignore
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@@ -16,12 +16,15 @@ from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusion
 def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
     dst_model = LatentDiffusionAutoencoder()
     x = torch.randn(1, 3, 512, 512)
-    mapping = create_state_dict_mapping(src_model, dst_model, [x])
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -41,9 +44,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).vae
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).vae  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)


 if __name__ == "__main__":

View file

@@ -1,13 +1,9 @@
 import torch
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
-from diffusers import StableDiffusionInpaintPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
+from diffusers import StableDiffusionInpaintPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 from refiners.foundationals.latent_diffusion.unet import UNet
@@ -17,20 +13,23 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = UNet(in_channels=9, clip_embedding_dim=768)
     x = torch.randn(1, 9, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)

-    mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -50,9 +49,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = StableDiffusionInpaintPipeline.from_pretrained(args.source).unet
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)


 if __name__ == "__main__":

View file

@@ -1,13 +1,9 @@
 import torch
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
-from diffusers import DiffusionPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 from refiners.foundationals.latent_diffusion.unet import UNet
@@ -17,20 +13,23 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = UNet(in_channels=4, clip_embedding_dim=768)
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)

-    mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -50,9 +49,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).unet
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)


 if __name__ == "__main__":

View file

@@ -1,10 +1,7 @@
 import torch
 from safetensors.torch import save_file  # type: ignore
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
@@ -28,7 +25,7 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     return {k: v.half() for k, v in state_dict.items()}


-def main():
+def main() -> None:
     import argparse

     parser = argparse.ArgumentParser()
@@ -49,7 +46,7 @@ def main():
     )
     args = parser.parse_args()
     src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2  # type: ignore
-    tensors = convert(src_model=src_model)
+    tensors = convert(src_model=src_model)  # type: ignore
     save_file(tensors=tensors, filename=args.output_file)

View file

@@ -1,10 +1,7 @@
 import torch
 from safetensors.torch import save_file  # type: ignore
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import DiffusionPipeline  # type: ignore
 from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
@@ -17,7 +14,7 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = SDXLUNet(in_channels=4)
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 2048)
     added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
@@ -60,8 +57,8 @@ def main() -> None:
     )
     args = parser.parse_args()
     src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_file(tensors=tensors, filename=args.output_file)


 if __name__ == "__main__":