mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add option to override unet weights for conversion
This commit is contained in:
parent
8f614e7647
commit
d14c5bd5f8
|
@ -27,6 +27,9 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
if args.override_weights is not None:
|
||||
sd = torch.load(args.override_weights) # type: ignore
|
||||
source.load_state_dict(sd)
|
||||
source_in_channels: int = source.config.in_channels # type: ignore
|
||||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
|
||||
|
@ -95,6 +98,15 @@ def main() -> None:
|
|||
" runwayml/stable-diffusion-v1-5"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--override-weights",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Path to a weights file to override the source model (keeping its config). "
|
||||
"This is useful for models distributed as .pth files."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
|
|
Loading…
Reference in a new issue