add option to override unet weights for conversion

This commit is contained in:
Pierre Chapuis 2024-02-21 18:05:12 +01:00
parent 8f614e7647
commit d14c5bd5f8

View file

@ -27,6 +27,9 @@ def setup_converter(args: Args) -> ModelConverter:
subfolder=args.subfolder,
low_cpu_mem_usage=False,
)
if args.override_weights is not None:
sd = torch.load(args.override_weights) # type: ignore
source.load_state_dict(sd)
source_in_channels: int = source.config.in_channels # type: ignore
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
@ -95,6 +98,15 @@ def main() -> None:
" runwayml/stable-diffusion-v1-5"
),
)
parser.add_argument(
"--override-weights",
type=str,
default=None,
help=(
"Path to a weights file to override the source model (keeping its config). "
"This is useful for models distributed as .pth files."
),
)
parser.add_argument(
"--to",
type=str,