fix: bad add_argument

This commit is contained in:
Laurent Fainsin 2023-02-10 16:20:15 +01:00
parent a88a55b8e8
commit 265e67bec8
2 changed files with 8 additions and 5 deletions

View file

@ -19,6 +19,7 @@ def main(argv: List[str]) -> None:
"""Main entrypoint for training and inference."""
# parse args
args = parse_args(argv)
print(args)
# stfu warnings
torch.set_float32_matmul_precision("medium")
@ -45,7 +46,7 @@ def main(argv: List[str]) -> None:
devices="auto",
strategy="dp",
max_epochs=args.epochs,
precision=16,
precision="bf16",
log_every_n_steps=25,
val_check_interval=100,
benchmark=True,
@ -65,6 +66,7 @@ def main(argv: List[str]) -> None:
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
# save predictions to csv
# TODO: define track upper bound
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
i = 0
f.write("id,label\n")

View file

@ -5,7 +5,7 @@ from typing import List # TODO: update to python 3.11
def parse_args(argv: List[str]) -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="pytorch lightning + classy vision TorchX example app",
description="Train and inference for AIorNOT challenge",
)
parser.add_argument(
@ -58,7 +58,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--skip_csv",
desc="skip export test inference to csv file",
action="store_true",
help="skip export test inference to csv file",
)
parser.add_argument(
"--load_ckpt",
@ -68,13 +69,13 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
parser.add_argument(
"--prefetch_factor",
type=int,
default=2,
default=3,
help="prefetch factor for dataloaders",
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
default=8,
help="number of workers for dataloaders",
)