diff --git a/.gitignore b/.gitignore index 9d60c97..da4c1bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ .venv/ .mypy_cache/ __pycache__/ + wandb/ +images/ *.pth +*.png +*.jpg diff --git a/.vscode/launch.json b/.vscode/launch.json index f7a9165..bccf130 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,9 +12,9 @@ "console": "integratedTerminal", "args": [ "--input", - "SM.png", + "images/SM.png", "--output", - "test.png", + "output.png", ], "justMyCode": true } diff --git a/SM.png.REMOVED.git-id b/SM.png.REMOVED.git-id deleted file mode 100644 index 21a947e..0000000 --- a/SM.png.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -c6d08aa612451072cfe32a3ee086d08342ed9dd9 \ No newline at end of file diff --git a/src/train.py b/src/train.py index 4f56bde..49f6dc5 100644 --- a/src/train.py +++ b/src/train.py @@ -18,7 +18,7 @@ from utils.paste import RandomPaste CHECKPOINT_DIR = Path("./checkpoints/") DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017") -DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/") +DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/") DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/") DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/") @@ -41,7 +41,7 @@ def get_args(): dest="batch_size", metavar="B", type=int, - default=16, + default=70, help="Batch size", ) parser.add_argument( @@ -92,11 +92,14 @@ def main(): torch.backends.cudnn.benchmark = True # 0. Create network - net = UNet(n_channels=3, n_classes=args.classes) + features = [16, 32, 64, 128] + net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features) logging.info( f"""Network: input channels: {net.n_channels} output channels: {net.n_classes} + nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)} + features: {features} """ ) @@ -138,7 +141,7 @@ def main(): ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) # 3. Create data loaders - loader_args = dict(batch_size=args.batch_size, num_workers=6, pin_memory=True) + loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True) train_loader = DataLoader(ds_train, shuffle=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) @@ -159,9 +162,9 @@ def main(): ), ) wandb.watch(net, log_freq=100) - # artifact = wandb.Artifact("model", type="model") - # artifact.add_file("model.pth") - # run.log_artifact(artifact) + artifact_model = wandb.Artifact("model", type="model") + artifact_model.add_file("model.pth") + run.log_artifact(artifact_model) logging.info( f"""Starting training: @@ -228,8 +231,7 @@ def main(): print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}") # save weights when epoch end - # torch.save(net.state_dict(), "model.pth") - # run.log_artifact(artifact) + torch.save(net.state_dict(), "model.pth") logging.info(f"model saved!") run.finish() diff --git a/test.png b/test.png deleted file mode 100644 index a0cad88..0000000 Binary files a/test.png and /dev/null differ