fix: bad add_argument
This commit is contained in:
parent
a88a55b8e8
commit
265e67bec8
|
@ -19,6 +19,7 @@ def main(argv: List[str]) -> None:
|
|||
"""Main entrypoint for training and inference."""
|
||||
# parse args
|
||||
args = parse_args(argv)
|
||||
print(args)
|
||||
|
||||
# stfu warnings
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
|
@ -45,7 +46,7 @@ def main(argv: List[str]) -> None:
|
|||
devices="auto",
|
||||
strategy="dp",
|
||||
max_epochs=args.epochs,
|
||||
precision=16,
|
||||
precision="bf16",
|
||||
log_every_n_steps=25,
|
||||
val_check_interval=100,
|
||||
benchmark=True,
|
||||
|
@ -65,6 +66,7 @@ def main(argv: List[str]) -> None:
|
|||
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
|
||||
|
||||
# save predictions to csv
|
||||
# TODO: define track upper bound
|
||||
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
|
||||
i = 0
|
||||
f.write("id,label\n")
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import List # TODO: update to python 3.11
|
|||
def parse_args(argv: List[str]) -> argparse.Namespace:
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="pytorch lightning + classy vision TorchX example app",
|
||||
description="Train and inference for AIorNOT challenge",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
@ -58,7 +58,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
|
|||
)
|
||||
parser.add_argument(
|
||||
"--skip_csv",
|
||||
desc="skip export test inference to csv file",
|
||||
action="store_true",
|
||||
help="skip export test inference to csv file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_ckpt",
|
||||
|
@ -68,13 +69,13 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--prefetch_factor",
|
||||
type=int,
|
||||
default=2,
|
||||
default=3,
|
||||
help="prefetch factor for dataloaders",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
type=int,
|
||||
default=4,
|
||||
default=8,
|
||||
help="number of workers for dataloaders",
|
||||
)
|
||||
|
||||
|
|
Reference in a new issue