mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-15 09:38:14 +00:00
90 lines
2.7 KiB
Python
90 lines
2.7 KiB
Python
|
import argparse
|
||
|
from pathlib import Path
|
||
|
|
||
|
from convert_diffusers_unet import Args as UNetArgs, setup_converter as setup_unet_converter
|
||
|
from huggingface_hub import hf_hub_download # type: ignore
|
||
|
|
||
|
from refiners.fluxion.utils import load_from_safetensors, save_to_safetensors
|
||
|
|
||
|
|
||
|
class Args(argparse.Namespace):
|
||
|
source_path: str
|
||
|
output_path: str | None
|
||
|
subfolder: str
|
||
|
half: bool
|
||
|
verbose: bool
|
||
|
reference_unet_path: str
|
||
|
|
||
|
|
||
|
def main() -> None:
    """Convert IC-Light patch weights into a Refiners-compatible safetensors file.

    Steps:
      1. Parse the command line.
      2. Locate the patch weights: ``--from`` is used directly when it ends
         with ``.safetensors``; otherwise it is treated as a Hugging Face Hub
         repository id and ``--filename`` is downloaded from it.
      3. Build the UNet converter from the reference UNet and use its mapping
         to convert the patch state dict.
      4. Save the converted tensors to ``--to`` (or a derived default name).
    """
    parser = argparse.ArgumentParser(description="Converts IC-Light patch weights to work with Refiners")
    parser.add_argument(
        "--from",
        type=str,
        dest="source_path",
        default="lllyasviel/ic-light",
        help=(
            # Only ``.safetensors`` files are loadable below, so ``.bin`` is
            # no longer advertised here.
            "Can be a path to a .safetensors file or a model name from the Hugging Face Hub. Default:"
            " lllyasviel/ic-light"
        ),
    )
    parser.add_argument("--filename", type=str, default="iclight_sd15_fc.safetensors", help="Filename inside the hub.")
    parser.add_argument(
        "--to",
        type=str,
        dest="output_path",
        default=None,
        help=(
            "Output path (.safetensors) for converted model. If not provided, the output is written to"
            " <source file name>-refiners.safetensors in the current directory."
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Prints additional information during conversion. Default: False",
    )
    parser.add_argument(
        "--reference-unet-path",
        type=str,
        dest="reference_unet_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Path to the reference UNet weights.",
    )
    args = parser.parse_args(namespace=Args())

    # A ``.safetensors`` suffix marks a local file; anything else is treated
    # as a Hub repository id from which ``--filename`` is fetched.
    is_local_file = args.source_path.endswith(".safetensors")
    if is_local_file:
        patch_file = Path(args.source_path)
    else:
        patch_file = Path(
            hf_hub_download(
                repo_id=args.source_path,
                filename=args.filename,
            )
        )

    if args.output_path is None:
        # Name the output after the file that is actually converted: the local
        # source file when one was given, otherwise the Hub filename.
        name_source = args.source_path if is_local_file else args.filename
        args.output_path = f"{Path(name_source).stem}-refiners.safetensors"

    patch_weights = load_from_safetensors(patch_file)

    unet_args = UNetArgs(
        source_path=args.reference_unet_path,
        subfolder="unet",
        half=False,
        # Forward the CLI flag instead of silently ignoring it.
        verbose=args.verbose,
        skip_init_check=True,
        override_weights=None,
    )
    converter = setup_unet_converter(args=unet_args)

    # The converter's private helper is called directly so that the patch
    # state dict is converted with the mapping obtained from the reference UNet.
    result = converter._convert_state_dict(  # pyright: ignore[reportPrivateUsage]
        source_state_dict=patch_weights,
        target_state_dict=converter.target_model.state_dict(),
        state_dict_mapping=converter.get_mapping(),
    )
    save_to_safetensors(path=args.output_path, tensors=result)
|
||
|
|
||
|
|
||
|
# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|