add threshold for clip image encoder conversion

This commit is contained in:
Cédric Deltheil 2023-09-07 14:46:15 +02:00 committed by Cédric Deltheil
parent c6fadd1c81
commit 946e7c2974

View file

@ -15,6 +15,7 @@ class Args(argparse.Namespace):
output_path: str | None output_path: str | None
half: bool half: bool
verbose: bool verbose: bool
threshold: float
def setup_converter(args: Args) -> ModelConverter: def setup_converter(args: Args) -> ModelConverter:
@ -79,7 +80,7 @@ def setup_converter(args: Args) -> ModelConverter:
# Ad hoc post-conversion steps # Ad hoc post-conversion steps
class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore
assert converter.compare_models((x,), threshold=1e-2) assert converter.compare_models((x,), threshold=args.threshold)
return converter return converter
@ -122,6 +123,7 @@ def main() -> None:
default=False, default=False,
help="Prints additional information during conversion. Default: False", help="Prints additional information during conversion. Default: False",
) )
parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
args = parser.parse_args(namespace=Args()) args = parser.parse_args(namespace=Args())
if args.output_path is None: if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors" args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"