mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
add threshold for clip image encoder conversion
This commit is contained in:
parent
c6fadd1c81
commit
946e7c2974
|
@ -15,6 +15,7 @@ class Args(argparse.Namespace):
|
||||||
output_path: str | None
|
output_path: str | None
|
||||||
half: bool
|
half: bool
|
||||||
verbose: bool
|
verbose: bool
|
||||||
|
threshold: float
|
||||||
|
|
||||||
|
|
||||||
def setup_converter(args: Args) -> ModelConverter:
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
|
@ -79,7 +80,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
||||||
# Ad hoc post-conversion steps
|
# Ad hoc post-conversion steps
|
||||||
class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore
|
class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore
|
||||||
|
|
||||||
assert converter.compare_models((x,), threshold=1e-2)
|
assert converter.compare_models((x,), threshold=args.threshold)
|
||||||
|
|
||||||
return converter
|
return converter
|
||||||
|
|
||||||
|
@ -122,6 +123,7 @@ def main() -> None:
|
||||||
default=False,
|
default=False,
|
||||||
help="Prints additional information during conversion. Default: False",
|
help="Prints additional information during conversion. Default: False",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
|
||||||
args = parser.parse_args(namespace=Args())
|
args = parser.parse_args(namespace=Args())
|
||||||
if args.output_path is None:
|
if args.output_path is None:
|
||||||
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||||
|
|
Loading…
Reference in a new issue