fix: bad add_argument
This commit is contained in:
parent
a88a55b8e8
commit
265e67bec8
|
@ -19,6 +19,7 @@ def main(argv: List[str]) -> None:
|
||||||
"""Main entrypoint for training and inference."""
|
"""Main entrypoint for training and inference."""
|
||||||
# parse args
|
# parse args
|
||||||
args = parse_args(argv)
|
args = parse_args(argv)
|
||||||
|
print(args)
|
||||||
|
|
||||||
# stfu warnings
|
# stfu warnings
|
||||||
torch.set_float32_matmul_precision("medium")
|
torch.set_float32_matmul_precision("medium")
|
||||||
|
@ -45,7 +46,7 @@ def main(argv: List[str]) -> None:
|
||||||
devices="auto",
|
devices="auto",
|
||||||
strategy="dp",
|
strategy="dp",
|
||||||
max_epochs=args.epochs,
|
max_epochs=args.epochs,
|
||||||
precision=16,
|
precision="bf16",
|
||||||
log_every_n_steps=25,
|
log_every_n_steps=25,
|
||||||
val_check_interval=100,
|
val_check_interval=100,
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
|
@ -65,6 +66,7 @@ def main(argv: List[str]) -> None:
|
||||||
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
|
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
|
||||||
|
|
||||||
# save predictions to csv
|
# save predictions to csv
|
||||||
|
# TODO: define track upper bound
|
||||||
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
|
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
|
||||||
i = 0
|
i = 0
|
||||||
f.write("id,label\n")
|
f.write("id,label\n")
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import List # TODO: update to python 3.11
|
||||||
def parse_args(argv: List[str]) -> argparse.Namespace:
|
def parse_args(argv: List[str]) -> argparse.Namespace:
|
||||||
"""Parse command line arguments."""
|
"""Parse command line arguments."""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="pytorch lightning + classy vision TorchX example app",
|
description="Train and inference for AIorNOT challenge",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -58,7 +58,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip_csv",
|
"--skip_csv",
|
||||||
desc="skip export test inference to csv file",
|
action="store_true",
|
||||||
|
help="skip export test inference to csv file",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load_ckpt",
|
"--load_ckpt",
|
||||||
|
@ -68,13 +69,13 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prefetch_factor",
|
"--prefetch_factor",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=3,
|
||||||
help="prefetch factor for dataloaders",
|
help="prefetch factor for dataloaders",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_workers",
|
"--num_workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=8,
|
||||||
help="number of workers for dataloaders",
|
help="number of workers for dataloaders",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Reference in a new issue