mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add option to override unet weights for conversion
This commit is contained in:
parent
8f614e7647
commit
d14c5bd5f8
|
@ -27,6 +27,9 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
subfolder=args.subfolder,
|
subfolder=args.subfolder,
|
||||||
low_cpu_mem_usage=False,
|
low_cpu_mem_usage=False,
|
||||||
)
|
)
|
||||||
|
if args.override_weights is not None:
|
||||||
|
sd = torch.load(args.override_weights) # type: ignore
|
||||||
|
source.load_state_dict(sd)
|
||||||
source_in_channels: int = source.config.in_channels # type: ignore
|
source_in_channels: int = source.config.in_channels # type: ignore
|
||||||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||||
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
|
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
|
||||||
|
@ -95,6 +98,15 @@ def main() -> None:
|
||||||
" runwayml/stable-diffusion-v1-5"
|
" runwayml/stable-diffusion-v1-5"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--override-weights",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Path to a weights file to override the source model (keeping its config). "
|
||||||
|
"This is useful for models distributed as .pth files."
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--to",
|
"--to",
|
||||||
type=str,
|
type=str,
|
||||||
|
|
Loading…
Reference in a new issue