From c9fba44f39bf4add33217c1a5ab20a9db8099893 Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Tue, 15 Aug 2023 14:35:17 +0200
Subject: [PATCH] fix typing for scripts

---
 scripts/convert-clip-weights.py               | 25 ++++++-----
 scripts/convert-controlnet-weights.py         | 36 +++++++++-------
 scripts/convert-lora-weights.py               | 42 ++++++++++---------
 scripts/convert-loras-to-sdwebui.py           | 38 +++++++++--------
 scripts/convert-sd-lda-weights.py             | 21 ++++++----
 scripts/convert-sd-unet-inpainting-weights.py | 31 +++++++-------
 scripts/convert-sd-unet-weights.py            | 31 +++++++-------
 scripts/convert-sdxl-text-encoder-2.py        |  9 ++--
 scripts/convert-sdxl-unet-weights.py          | 11 ++---
 9 files changed, 125 insertions(+), 119 deletions(-)

diff --git a/scripts/convert-clip-weights.py b/scripts/convert-clip-weights.py
index e4fd93e..72eb760 100644
--- a/scripts/convert-clip-weights.py
+++ b/scripts/convert-clip-weights.py
@@ -1,13 +1,9 @@
 import torch
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
 
-from diffusers import DiffusionPipeline
-from transformers.models.clip.modeling_clip import CLIPTextModel
+from diffusers import DiffusionPipeline  # type: ignore
+from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 
 
@@ -16,12 +12,15 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     dst_model = CLIPTextEncoderL()
     x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    mapping = create_state_dict_mapping(src_model, dst_model, [x])
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -41,9 +40,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).text_encoder
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
diff --git a/scripts/convert-controlnet-weights.py b/scripts/convert-controlnet-weights.py
index c2b6142..ae0cd5e 100644
--- a/scripts/convert-controlnet-weights.py
+++ b/scripts/convert-controlnet-weights.py
@@ -1,10 +1,10 @@
 import torch
-from diffusers import ControlNetModel
-from safetensors.torch import save_file
+from diffusers import ControlNetModel  # type: ignore
 from refiners.fluxion.utils import (
     forward_order_of_execution,
     verify_shape_match,
     convert_state_dict,
+    save_to_safetensors,
 )
 from refiners.foundationals.latent_diffusion.controlnet import Controlnet
 from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
@@ -16,16 +16,16 @@
 def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     controlnet = Controlnet(name="mycn")
     condition = torch.randn(1, 3, 512, 512)
-    controlnet.set_controlnet_condition(condition)
+    controlnet.set_controlnet_condition(condition=condition)
 
-    unet = UNet(4, clip_embedding_dim=768)
-    unet.insert(0, controlnet)
+    unet = UNet(in_channels=4, clip_embedding_dim=768)
+    unet.insert(index=0, module=controlnet)
     clip_text_embedding = torch.rand(1, 77, 768)
-    unet.set_clip_text_embedding(clip_text_embedding)
+    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
 
     scheduler = DPMSolver(num_inference_steps=10)
-    timestep = scheduler.timesteps[0].unsqueeze(0)
-    unet.set_timestep(timestep.unsqueeze(0))
+    timestep = scheduler.timesteps[0].unsqueeze(dim=0)
+    unet.set_timestep(timestep=timestep.unsqueeze(dim=0))
 
     x = torch.randn(1, 4, 64, 64)
 
@@ -33,8 +33,8 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     # to diffusers in order, since we compute the residuals inline instead of
     # in a separate step.
 
-    source_order = forward_order_of_execution(controlnet_src, (x, timestep, clip_text_embedding, condition))
-    target_order = forward_order_of_execution(controlnet, (x,))
+    source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition))  # type: ignore
+    target_order = forward_order_of_execution(module=controlnet, example_args=(x,))
 
     broken_k = ("Conv2d", (torch.Size([320, 320, 1, 1]), torch.Size([320])))
 
@@ -162,7 +162,7 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
     assert target_order[broken_k] == expected_target_order
     source_order[broken_k] = fixed_source_order
 
-    assert verify_shape_match(source_order, target_order)
+    assert verify_shape_match(source_order=source_order, target_order=target_order)
 
     mapping: dict[str, str] = {}
     for model_type_shape in source_order:
@@ -170,12 +170,16 @@ def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
         target_keys = target_order[model_type_shape]
         mapping.update(zip(target_keys, source_keys))
 
-    state_dict = convert_state_dict(controlnet_src.state_dict(), controlnet.state_dict(), state_dict_mapping=mapping)
+    state_dict = convert_state_dict(
+        source_state_dict=controlnet_src.state_dict(),
+        target_state_dict=controlnet.state_dict(),
+        state_dict_mapping=mapping,
+    )
 
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -194,9 +198,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    controlnet_src = ControlNetModel.from_pretrained(args.source)
-    tensors = convert(controlnet_src)
-    save_file(tensors, args.output_file)
+    controlnet_src = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source)  # type: ignore
+    tensors = convert(controlnet_src=controlnet_src)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
diff --git a/scripts/convert-lora-weights.py b/scripts/convert-lora-weights.py
index 1323945..8d45cb0 100644
--- a/scripts/convert-lora-weights.py
+++ b/scripts/convert-lora-weights.py
@@ -1,6 +1,7 @@
 # Note: this conversion script currently only support simple LoRAs which adapt
 # the UNet's attentions such as https://huggingface.co/pcuenq/pokemon-lora
 
+from typing import cast
 import torch
 from torch.nn.init import zeros_
 from torch.nn import Parameter as TorchParameter
@@ -13,7 +14,7 @@
 from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras
 from refiners.adapters.lora import Lora
 from refiners.fluxion.utils import create_state_dict_mapping
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline  # type: ignore
 
 
 def get_weight(linear: fl.Linear) -> torch.Tensor:
@@ -24,17 +25,17 @@
 def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, torch.Tensor]:
     weights: list[torch.Tensor] = []
     for lora in module.layers(layer_type=Lora):
-        linears = list(lora.layers(fl.Linear))
+        linears = list(lora.layers(layer_type=fl.Linear))
         assert len(linears) == 2
-        weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
-    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
+        weights.extend((get_weight(linear=linears[1]), get_weight(linear=linears[0])))  # aka (up_weight, down_weight)
+    return {f"{key_prefix}{i:03d}": w for i, w in enumerate(iterable=weights)}
 
 
 @torch.no_grad()
 def process(source: str, base_model: str, output_file: str) -> None:
     diffusers_state_dict = torch.load(source, map_location="cpu")  # type: ignore
-    diffusers_sd = DiffusionPipeline.from_pretrained(base_model)  # type: ignore
-    diffusers_model = diffusers_sd.unet
+    diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model)  # type: ignore
+    diffusers_model = cast(fl.Module, diffusers_sd.unet)  # type: ignore
 
     refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
     target = LoraTarget.CrossAttention
@@ -44,21 +45,23 @@ def process(source: str, base_model: str, output_file: str) -> None:
     ].shape[0]
 
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
-    refiners_model.set_timestep(timestep)
-    refiners_model.set_clip_text_embedding(clip_text_embeddings)
+    refiners_model.set_timestep(timestep=timestep)
+    refiners_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     refiners_args = (x,)
 
     diffusers_args = (x, timestep, clip_text_embeddings)
 
-    diffusers_to_refiners = create_state_dict_mapping(refiners_model, diffusers_model, refiners_args, diffusers_args)
-    assert diffusers_to_refiners
+    diffusers_to_refiners = create_state_dict_mapping(
+        source_model=refiners_model, target_model=diffusers_model, source_args=refiners_args, target_args=diffusers_args
+    )
+    assert diffusers_to_refiners is not None
 
-    apply_loras_to_target(refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
+    apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
     for layer in refiners_model.layers(layer_type=Lora):
-        zeros_(layer.Linear_1.weight)
+        zeros_(tensor=layer.Linear_1.weight)
 
     targets = {k.split("_lora.")[0] for k in diffusers_state_dict.keys()}
     for target_k in targets:
@@ -66,23 +69,24 @@ def process(source: str, base_model: str, output_file: str) -> None:
         r = [v for k, v in diffusers_to_refiners.items() if k.startswith(f"{k_p}.{k_s}")]
         assert len(r) == 1
         orig_k = r[0]
-        orig_path = orig_k.split(".")
+        orig_path = orig_k.split(sep=".")
         p = refiners_model
         for seg in orig_path[:-1]:
             p = p[seg]
+        assert isinstance(p, fl.Chain)
         last_seg = (
             "LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
         )
-        p_down = TorchParameter(diffusers_state_dict[f"{target_k}_lora.down.weight"])
-        p_up = TorchParameter(diffusers_state_dict[f"{target_k}_lora.up.weight"])
+        p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
+        p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
         p[last_seg].Lora.load_weights(p_down, p_up)
 
-    state_dict = build_loras_safetensors(refiners_model, key_prefix="unet.")
+    state_dict = build_loras_safetensors(module=refiners_model, key_prefix="unet.")
     assert len(state_dict) == 320
-    save_to_safetensors(output_file, tensors=state_dict, metadata=metadata)
+    save_to_safetensors(path=output_file, tensors=state_dict, metadata=metadata)
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
diff --git a/scripts/convert-loras-to-sdwebui.py b/scripts/convert-loras-to-sdwebui.py
index 8d4baea..931c10b 100644
--- a/scripts/convert-loras-to-sdwebui.py
+++ b/scripts/convert-loras-to-sdwebui.py
@@ -1,4 +1,8 @@
-from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors, save_to_safetensors
+from refiners.fluxion.utils import (
+    load_from_safetensors,
+    load_metadata_from_safetensors,
+    save_to_safetensors,
+)
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.unet import UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget
@@ -8,33 +12,33 @@ from refiners.fluxion.utils import create_state_dict_mapping
 
 import torch
-from diffusers import DiffusionPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from transformers.models.clip.modeling_clip import CLIPTextModel
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
+from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 
 
 @torch.no_grad()
 def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
     src_args = (x, timestep, clip_text_embeddings)
 
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)
 
-    return create_state_dict_mapping(src_model, dst_model, src_args, dst_args)  # type: ignore
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
 
 
 @torch.no_grad()
 def create_text_encoder_mapping(src_model: CLIPTextModel, dst_model: CLIPTextEncoderL) -> dict[str, str] | None:
     x = dst_model.tokenizer("Nice cat", sequence_length=77)
-    return create_state_dict_mapping(src_model, dst_model, [x])  # type: ignore
+    return create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -61,11 +65,11 @@ def main():
     )
     args = parser.parse_args()
 
-    metadata = load_metadata_from_safetensors(args.input_file)
+    metadata = load_metadata_from_safetensors(path=args.input_file)
     assert metadata is not None
-    tensors = load_from_safetensors(args.input_file)
+    tensors = load_from_safetensors(path=args.input_file)
 
-    diffusers_sd = DiffusionPipeline.from_pretrained(args.sd15)  # type: ignore
+    diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.sd15)  # type: ignore
 
     state_dict: dict[str, torch.Tensor] = {}
 
@@ -110,16 +114,16 @@ def main():
         # Compute the corresponding diffusers' keys where LoRA layers must be applied
         lora_injection_points: list[str] = [
             refiners_to_diffusers[submodule_to_key[linear]]
-            for target in [LoraTarget(t) for t in meta_value.split(",")]
+            for target in [LoraTarget(t) for t in meta_value.split(sep=",")]
             for layer in dst_model.layers(layer_type=target.get_class())
-            for linear in layer.layers(fl.Linear)
+            for linear in layer.layers(layer_type=fl.Linear)
         ]
 
         lora_weights = [w for w in [tensors[k] for k in sorted(tensors) if k.startswith(key_prefix)]]
         assert len(lora_injection_points) == len(lora_weights) // 2
 
         # Map LoRA weights to each key using SD-WebUI conventions (proper prefix and suffix, underscores)
-        for i, diffusers_key in enumerate(lora_injection_points):
+        for i, diffusers_key in enumerate(iterable=lora_injection_points):
             lora_key = lora_prefix + diffusers_key.replace(".", "_")
             # Note: no ".alpha" weights (those are used to scale the LoRA by alpha/rank). Refiners uses a scale = 1.0
             # by default (see `lora_calc_updown` in SD-WebUI for more details)
@@ -127,7 +131,7 @@ def main():
             state_dict[lora_key + ".lora_down.weight"] = lora_weights[2 * i + 1]
 
     assert state_dict
-    save_to_safetensors(args.output_file, state_dict)
+    save_to_safetensors(path=args.output_file, tensors=state_dict)
 
 
 if __name__ == "__main__":
diff --git a/scripts/convert-sd-lda-weights.py b/scripts/convert-sd-lda-weights.py
index b7cfccc..e800292 100644
--- a/scripts/convert-sd-lda-weights.py
+++ b/scripts/convert-sd-lda-weights.py
@@ -1,13 +1,13 @@
 import torch
-from safetensors.torch import save_file
 from refiners.fluxion.utils import (
     create_state_dict_mapping,
     convert_state_dict,
+    save_to_safetensors,
 )
 
-from diffusers import DiffusionPipeline
-from diffusers.models.autoencoder_kl import AutoencoderKL
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.autoencoder_kl import AutoencoderKL  # type: ignore
 
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 
 
@@ -16,12 +16,15 @@ from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusion
 def convert(src_model: AutoencoderKL) -> dict[str, torch.Tensor]:
     dst_model = LatentDiffusionAutoencoder()
     x = torch.randn(1, 3, 512, 512)
-    mapping = create_state_dict_mapping(src_model, dst_model, [x])
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -41,9 +44,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).vae
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).vae  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
diff --git a/scripts/convert-sd-unet-inpainting-weights.py b/scripts/convert-sd-unet-inpainting-weights.py
index 8b2c5f0..80a823d 100644
--- a/scripts/convert-sd-unet-inpainting-weights.py
+++ b/scripts/convert-sd-unet-inpainting-weights.py
@@ -1,13 +1,9 @@
 import torch
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
 
-from diffusers import StableDiffusionInpaintPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers import StableDiffusionInpaintPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 
 from refiners.foundationals.latent_diffusion.unet import UNet
 
 
@@ -17,20 +13,23 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = UNet(in_channels=9, clip_embedding_dim=768)
 
     x = torch.randn(1, 9, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)
 
-    mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -50,9 +49,9 @@ def main():
         help="Path for the output file",
     )
     args = parser.parse_args()
-    src_model = StableDiffusionInpaintPipeline.from_pretrained(args.source).unet
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = StableDiffusionInpaintPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
diff --git a/scripts/convert-sd-unet-weights.py b/scripts/convert-sd-unet-weights.py
index 9654280..1c3e575 100644
--- a/scripts/convert-sd-unet-weights.py
+++ b/scripts/convert-sd-unet-weights.py
@@ -1,13 +1,9 @@
 import torch
-from safetensors.torch import save_file
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
 
-from diffusers import DiffusionPipeline
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from diffusers import DiffusionPipeline  # type: ignore
+from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 
 from refiners.foundationals.latent_diffusion.unet import UNet
 
 
@@ -17,20 +13,23 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = UNet(in_channels=4, clip_embedding_dim=768)
 
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)
 
     src_args = (x, timestep, clip_text_embeddings)
-    dst_model.set_timestep(timestep)
-    dst_model.set_clip_text_embedding(clip_text_embeddings)
+    dst_model.set_timestep(timestep=timestep)
+    dst_model.set_clip_text_embedding(clip_text_embedding=clip_text_embeddings)
     dst_args = (x,)
 
-    mapping = create_state_dict_mapping(src_model, dst_model, src_args, dst_args)
-    state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
+    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=src_args, target_args=dst_args)  # type: ignore
+    assert mapping is not None, "Model conversion failed"
+    state_dict = convert_state_dict(
+        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
+    )
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -50,9 +49,9 @@ def main():
         help="Path for the output file",
    )
     args = parser.parse_args()
-    src_model = DiffusionPipeline.from_pretrained(args.source).unet
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_to_safetensors(path=args.output_file, tensors=tensors)
 
 
 if __name__ == "__main__":
diff --git a/scripts/convert-sdxl-text-encoder-2.py b/scripts/convert-sdxl-text-encoder-2.py
index 3f77cd2..ad67020 100644
--- a/scripts/convert-sdxl-text-encoder-2.py
+++ b/scripts/convert-sdxl-text-encoder-2.py
@@ -1,10 +1,7 @@
 import torch
 from safetensors.torch import save_file  # type: ignore
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 
 from diffusers import DiffusionPipeline  # type: ignore
 from transformers.models.clip.modeling_clip import CLIPTextModel  # type: ignore
 
@@ -28,7 +25,7 @@ def convert(src_model: CLIPTextModel) -> dict[str, torch.Tensor]:
     return {k: v.half() for k, v in state_dict.items()}
 
 
-def main():
+def main() -> None:
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -49,7 +46,7 @@ def main():
     )
     args = parser.parse_args()
     src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).text_encoder_2  # type: ignore
-    tensors = convert(src_model=src_model)
+    tensors = convert(src_model=src_model)  # type: ignore
     save_file(tensors=tensors, filename=args.output_file)
 
 
diff --git a/scripts/convert-sdxl-unet-weights.py b/scripts/convert-sdxl-unet-weights.py
index 60ba93b..de07f05 100644
--- a/scripts/convert-sdxl-unet-weights.py
+++ b/scripts/convert-sdxl-unet-weights.py
@@ -1,10 +1,7 @@
 import torch
 from safetensors.torch import save_file  # type: ignore
-from refiners.fluxion.utils import (
-    create_state_dict_mapping,
-    convert_state_dict,
-)
+from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 
 from diffusers import DiffusionPipeline  # type: ignore
 from diffusers.models.unet_2d_condition import UNet2DConditionModel  # type: ignore
 
@@ -17,7 +14,7 @@ def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
     dst_model = SDXLUNet(in_channels=4)
 
     x = torch.randn(1, 4, 32, 32)
-    timestep = torch.tensor([0])
+    timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 2048)
 
     added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}
@@ -60,8 +57,8 @@ def main() -> None:
     )
     args = parser.parse_args()
     src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.source).unet  # type: ignore
-    tensors = convert(src_model)
-    save_file(tensors, args.output_file)
+    tensors = convert(src_model=src_model)  # type: ignore
+    save_file(tensors=tensors, filename=args.output_file)
 
 
 if __name__ == "__main__":
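
All nine scripts converge on the same conversion recipe after this patch: run the source (diffusers/transformers) model and the target refiners model on a shared dummy input to derive a key mapping, assert that a mapping was actually found, rename the keys, downcast to fp16, and write the result out as safetensors. Below is a minimal sketch of that recipe as it reads post-patch in scripts/convert-clip-weights.py; the wrapper name convert_text_encoder is hypothetical, but the refiners and diffusers calls, keyword arguments, and targeted "# type: ignore" comments are taken verbatim from the diff above.

import torch
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

from diffusers import DiffusionPipeline  # type: ignore


@torch.no_grad()
def convert_text_encoder(source: str, output_file: str) -> None:  # hypothetical wrapper around convert() + main()
    # diffusers ships without type stubs, hence the targeted ignore comments
    # instead of untyped calls.
    src_model = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=source).text_encoder  # type: ignore
    dst_model = CLIPTextEncoderL()

    # Trace both models on the same tokenized prompt to derive the
    # source-key -> target-key mapping; a None result means the conversion failed.
    x = dst_model.tokenizer("Nice cat", sequence_length=77)
    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
    assert mapping is not None, "Model conversion failed"

    # Rename the source keys, downcast to fp16, and serialize.
    state_dict = convert_state_dict(
        source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping
    )
    save_to_safetensors(path=output_file, tensors={k: v.half() for k, v in state_dict.items()})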