mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
add threshold for clip image encoder conversion
This commit is contained in:
parent
c6fadd1c81
commit
946e7c2974
|
@ -15,6 +15,7 @@ class Args(argparse.Namespace):
|
|||
output_path: str | None
|
||||
half: bool
|
||||
verbose: bool
|
||||
threshold: float
|
||||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
|
@ -79,7 +80,7 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
# Ad hoc post-conversion steps
|
||||
class_embedding.parameter = torch.nn.Parameter(source.vision_model.embeddings.class_embedding.clone()) # type: ignore
|
||||
|
||||
assert converter.compare_models((x,), threshold=1e-2)
|
||||
assert converter.compare_models((x,), threshold=args.threshold)
|
||||
|
||||
return converter
|
||||
|
||||
|
@ -122,6 +123,7 @@ def main() -> None:
|
|||
default=False,
|
||||
help="Prints additional information during conversion. Default: False",
|
||||
)
|
||||
parser.add_argument("--threshold", type=float, default=1e-2, help="Threshold for model comparison. Default: 1e-2")
|
||||
args = parser.parse_args(namespace=Args())
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}-{args.subfolder}.safetensors"
|
||||
|
|
Loading…
Reference in a new issue