fix: bad add_argument

This commit is contained in:
Laurent Fainsin 2023-02-10 16:20:15 +01:00
parent a88a55b8e8
commit 265e67bec8
2 changed files with 8 additions and 5 deletions

View file

@ -19,6 +19,7 @@ def main(argv: List[str]) -> None:
"""Main entrypoint for training and inference.""" """Main entrypoint for training and inference."""
# parse args # parse args
args = parse_args(argv) args = parse_args(argv)
print(args)
# stfu warnings # stfu warnings
torch.set_float32_matmul_precision("medium") torch.set_float32_matmul_precision("medium")
@ -45,7 +46,7 @@ def main(argv: List[str]) -> None:
devices="auto", devices="auto",
strategy="dp", strategy="dp",
max_epochs=args.epochs, max_epochs=args.epochs,
precision=16, precision="bf16",
log_every_n_steps=25, log_every_n_steps=25,
val_check_interval=100, val_check_interval=100,
benchmark=True, benchmark=True,
@ -65,6 +66,7 @@ def main(argv: List[str]) -> None:
test_results = trainer.predict(model, dataloaders=model.test_dataloader()) test_results = trainer.predict(model, dataloaders=model.test_dataloader())
# save predictions to csv # save predictions to csv
# TODO: define track upper bound
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f: with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
i = 0 i = 0
f.write("id,label\n") f.write("id,label\n")

View file

@ -5,7 +5,7 @@ from typing import List # TODO: update to python 3.11
def parse_args(argv: List[str]) -> argparse.Namespace: def parse_args(argv: List[str]) -> argparse.Namespace:
"""Parse command line arguments.""" """Parse command line arguments."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="pytorch lightning + classy vision TorchX example app", description="Train and inference for AIorNOT challenge",
) )
parser.add_argument( parser.add_argument(
@ -58,7 +58,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
) )
parser.add_argument( parser.add_argument(
"--skip_csv", "--skip_csv",
desc="skip export test inference to csv file", action="store_true",
help="skip export test inference to csv file",
) )
parser.add_argument( parser.add_argument(
"--load_ckpt", "--load_ckpt",
@ -68,13 +69,13 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--prefetch_factor", "--prefetch_factor",
type=int, type=int,
default=2, default=3,
help="prefetch factor for dataloaders", help="prefetch factor for dataloaders",
) )
parser.add_argument( parser.add_argument(
"--num_workers", "--num_workers",
type=int, type=int,
default=4, default=8,
help="number of workers for dataloaders", help="number of workers for dataloaders",
) )