diff --git a/scripts/conversion/convert_diffusers_unet.py b/scripts/conversion/convert_diffusers_unet.py index d46f7e1..014db37 100644 --- a/scripts/conversion/convert_diffusers_unet.py +++ b/scripts/conversion/convert_diffusers_unet.py @@ -27,6 +27,9 @@ def setup_converter(args: Args) -> ModelConverter: subfolder=args.subfolder, low_cpu_mem_usage=False, ) + if args.override_weights is not None: + sd = torch.load(args.override_weights) # type: ignore + source.load_state_dict(sd) source_in_channels: int = source.config.in_channels # type: ignore source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore @@ -95,6 +98,15 @@ def main() -> None: " runwayml/stable-diffusion-v1-5" ), ) + parser.add_argument( + "--override-weights", + type=str, + default=None, + help=( + "Path to a weights file to override the source model (keeping its config). " + "This is useful for models distributed as .pth files." + ), + ) parser.add_argument( "--to", type=str,