From 265e67bec8b74addaa039232c34e0384a8dd30e3 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Fri, 10 Feb 2023 16:20:15 +0100 Subject: [PATCH] fix: bad add_argument --- src/main.py | 4 +++- src/parse.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/main.py b/src/main.py index 48b00e5..a511448 100644 --- a/src/main.py +++ b/src/main.py @@ -19,6 +19,7 @@ def main(argv: List[str]) -> None: """Main entrypoint for training and inference.""" # parse args args = parse_args(argv) + print(args) # stfu warnings torch.set_float32_matmul_precision("medium") @@ -45,7 +46,7 @@ def main(argv: List[str]) -> None: devices="auto", strategy="dp", max_epochs=args.epochs, - precision=16, + precision="bf16", log_every_n_steps=25, val_check_interval=100, benchmark=True, @@ -65,6 +66,7 @@ def main(argv: List[str]) -> None: test_results = trainer.predict(model, dataloaders=model.test_dataloader()) # save predictions to csv + # TODO: define track upper bound with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f: i = 0 f.write("id,label\n") diff --git a/src/parse.py b/src/parse.py index 5617efc..1c9a307 100644 --- a/src/parse.py +++ b/src/parse.py @@ -5,7 +5,7 @@ from typing import List # TODO: update to python 3.11 def parse_args(argv: List[str]) -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser( - description="pytorch lightning + classy vision TorchX example app", + description="Train and inference for AIorNOT challenge", ) parser.add_argument( @@ -58,7 +58,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace: ) parser.add_argument( "--skip_csv", - desc="skip export test inference to csv file", + action="store_true", + help="skip export test inference to csv file", ) parser.add_argument( "--load_ckpt", @@ -68,13 +69,13 @@ def parse_args(argv: List[str]) -> argparse.Namespace: parser.add_argument( "--prefetch_factor", type=int, - default=2, + default=3, help="prefetch factor for dataloaders", ) parser.add_argument( "--num_workers", type=int, - default=4, + default=8, help="number of workers for dataloaders", )