mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
208 lines
7.5 KiB
Python
208 lines
7.5 KiB
Python
import torch
|
|
from diffusers import ControlNetModel # type: ignore
|
|
from refiners.fluxion.utils import (
|
|
forward_order_of_execution,
|
|
verify_shape_match,
|
|
convert_state_dict,
|
|
save_to_safetensors,
|
|
)
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
|
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
|
from refiners.foundationals.latent_diffusion import SD1UNet
|
|
|
|
|
|
@torch.no_grad()
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
    """Build a refiners-keyed, half-precision state dict from a diffusers ControlNet.

    Both the source model and a freshly built ``SD1Controlnet`` are passed to
    ``forward_order_of_execution`` with random example inputs. The resulting
    per-(layer type, parameter shapes) name lists are paired up to form a
    target-name -> source-name mapping, which is handed to
    ``convert_state_dict``.

    Args:
        controlnet_src: The diffusers ControlNet whose weights are converted.

    Returns:
        The converted state dict, with every tensor cast via ``.half()``.

    Raises:
        AssertionError: If the observed 1x1 ``Conv2d`` orderings differ from
            the hard-coded expectations, or if ``verify_shape_match`` fails.
    """
    # Target model, fed a random conditioning image.
    controlnet = SD1Controlnet(name="mycn")
    condition = torch.randn(1, 3, 512, 512)
    controlnet.set_controlnet_condition(condition=condition)

    # The ControlNet is inserted at the front of a UNet, which is given the
    # text embedding and timestep through its setters.
    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
    unet.insert(index=0, module=controlnet)
    clip_text_embedding = torch.rand(1, 77, 768)
    unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

    scheduler = DPMSolver(num_inference_steps=10)
    timestep = scheduler.timesteps[0].unsqueeze(dim=0)
    # set_timestep receives the timestep with one more leading dimension than
    # the value later passed to the diffusers model.
    unet.set_timestep(timestep=timestep.unsqueeze(dim=0))

    x = torch.randn(1, 4, 64, 64)

    # We need the hack below because our implementation is not strictly equivalent
    # to diffusers in order, since we compute the residuals inline instead of
    # in a separate step.
    source_order = forward_order_of_execution(module=controlnet_src, example_args=(x, timestep, clip_text_embedding, condition))  # type: ignore
    target_order = forward_order_of_execution(module=controlnet, example_args=(x,))

    # One entry per square 1x1 Conv2d width whose source ordering has to be
    # patched: channels -> (expected source order, expected target order,
    # source order rearranged to line up with the target order).
    reorderings: dict[int, tuple[list[str], list[str], list[str]]] = {
        320: (
            [
                "down_blocks.0.attentions.0.proj_in",
                "down_blocks.0.attentions.0.proj_out",
                "down_blocks.0.attentions.1.proj_in",
                "down_blocks.0.attentions.1.proj_out",
                "controlnet_down_blocks.0",
                "controlnet_down_blocks.1",
                "controlnet_down_blocks.2",
                "controlnet_down_blocks.3",
            ],
            [
                "DownBlocks.Chain_1.Passthrough.Conv2d",
                "DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "DownBlocks.Chain_2.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "DownBlocks.Chain_2.Passthrough.Conv2d",
                "DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "DownBlocks.Chain_3.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "DownBlocks.Chain_3.Passthrough.Conv2d",
                "DownBlocks.Chain_4.Passthrough.Conv2d",
            ],
            [
                "controlnet_down_blocks.0",
                "down_blocks.0.attentions.0.proj_in",
                "down_blocks.0.attentions.0.proj_out",
                "controlnet_down_blocks.1",
                "down_blocks.0.attentions.1.proj_in",
                "down_blocks.0.attentions.1.proj_out",
                "controlnet_down_blocks.2",
                "controlnet_down_blocks.3",
            ],
        ),
        640: (
            [
                "down_blocks.1.attentions.0.proj_in",
                "down_blocks.1.attentions.0.proj_out",
                "down_blocks.1.attentions.1.proj_in",
                "down_blocks.1.attentions.1.proj_out",
                "controlnet_down_blocks.4",
                "controlnet_down_blocks.5",
                "controlnet_down_blocks.6",
            ],
            [
                "DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "DownBlocks.Chain_5.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "DownBlocks.Chain_5.Passthrough.Conv2d",
                "DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "DownBlocks.Chain_6.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "DownBlocks.Chain_6.Passthrough.Conv2d",
                "DownBlocks.Chain_7.Passthrough.Conv2d",
            ],
            [
                "down_blocks.1.attentions.0.proj_in",
                "down_blocks.1.attentions.0.proj_out",
                "controlnet_down_blocks.4",
                "down_blocks.1.attentions.1.proj_in",
                "down_blocks.1.attentions.1.proj_out",
                "controlnet_down_blocks.5",
                "controlnet_down_blocks.6",
            ],
        ),
        1280: (
            [
                "down_blocks.2.attentions.0.proj_in",
                "down_blocks.2.attentions.0.proj_out",
                "down_blocks.2.attentions.1.proj_in",
                "down_blocks.2.attentions.1.proj_out",
                "mid_block.attentions.0.proj_in",
                "mid_block.attentions.0.proj_out",
                "controlnet_down_blocks.7",
                "controlnet_down_blocks.8",
                "controlnet_down_blocks.9",
                "controlnet_down_blocks.10",
                "controlnet_down_blocks.11",
                "controlnet_mid_block",
            ],
            [
                "DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "DownBlocks.Chain_8.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "DownBlocks.Chain_8.Passthrough.Conv2d",
                "DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "DownBlocks.Chain_9.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "DownBlocks.Chain_9.Passthrough.Conv2d",
                "DownBlocks.Chain_10.Passthrough.Conv2d",
                "DownBlocks.Chain_11.Passthrough.Conv2d",
                "DownBlocks.Chain_12.Passthrough.Conv2d",
                "MiddleBlock.CLIPLCrossAttention.Chain.Chain_1.Conv2d",
                "MiddleBlock.CLIPLCrossAttention.Chain.Chain_3.Conv2d",
                "MiddleBlock.Passthrough.Conv2d",
            ],
            [
                "down_blocks.2.attentions.0.proj_in",
                "down_blocks.2.attentions.0.proj_out",
                "controlnet_down_blocks.7",
                "down_blocks.2.attentions.1.proj_in",
                "down_blocks.2.attentions.1.proj_out",
                "controlnet_down_blocks.8",
                "controlnet_down_blocks.9",
                "controlnet_down_blocks.10",
                "controlnet_down_blocks.11",
                "mid_block.attentions.0.proj_in",
                "mid_block.attentions.0.proj_out",
                "controlnet_mid_block",
            ],
        ),
    }

    for channels, (expected_source_order, expected_target_order, fixed_source_order) in reorderings.items():
        broken_k = ("Conv2d", (torch.Size([channels, channels, 1, 1]), torch.Size([channels])))

        # Only patch when both sides look exactly as anticipated; anything
        # else means the hard-coded fix no longer applies.
        assert source_order[broken_k] == expected_source_order
        assert target_order[broken_k] == expected_target_order
        source_order[broken_k] = fixed_source_order

    assert verify_shape_match(source_order=source_order, target_order=target_order)

    # Pair names positionally within each group: target name -> source name.
    mapping: dict[str, str] = {
        target_key: source_key
        for model_type_shape in source_order
        for target_key, source_key in zip(target_order[model_type_shape], source_order[model_type_shape])
    }

    state_dict = convert_state_dict(
        source_state_dict=controlnet_src.state_dict(),
        target_state_dict=controlnet.state_dict(),
        state_dict_mapping=mapping,
    )

    return {name: tensor.half() for name, tensor in state_dict.items()}
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: load a diffusers ControlNet, convert it, save it.

    Options:
        --from: Value passed to ``ControlNetModel.from_pretrained`` as
            ``pretrained_model_name_or_path`` (required).
        --output-file: Path passed to ``save_to_safetensors``; defaults to
            ``output.safetensors``.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--from", type=str, dest="source", required=True, help="Source model")
    cli.add_argument(
        "--output-file",
        type=str,
        required=False,
        default="output.safetensors",
        help="Path for the output file",
    )
    options = cli.parse_args()

    source_model = ControlNetModel.from_pretrained(pretrained_model_name_or_path=options.source)  # type: ignore
    save_to_safetensors(
        path=options.output_file,
        tensors=convert(controlnet_src=source_model),  # type: ignore
    )
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|