From d14c5bd5f8435f294d3244d28ab0c4246d50a5b2 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 21 Feb 2024 18:05:12 +0100 Subject: [PATCH] add option to override unet weights for conversion --- scripts/conversion/convert_diffusers_unet.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py index d46f7e1..014db37 100644 --- a/scripts/conversion/convert_diffusers_unet.py +++ b/scripts/conversion/convert_diffusers_unet.py @@ -27,6 +27,9 @@ def setup_converter(args: Args) -> ModelConverter: subfolder=args.subfolder, low_cpu_mem_usage=False, ) + if args.override_weights is not None: + sd = torch.load(args.override_weights) # type: ignore + source.load_state_dict(sd) source_in_channels: int = source.config.in_channels # type: ignore source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore @@ -95,6 +98,15 @@ def main() -> None: " runwayml/stable-diffusion-v1-5" ), ) + parser.add_argument( + "--override-weights", + type=str, + default=None, + help=( + "Path to a weights file to override the source model (keeping its config). " + "This is useful for models distributed as .pth files." + ), + ) parser.add_argument( "--to", type=str,