diff --git a/.gitignore b/.gitignore index 32a40fb..cdc3479 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ wandb/ images/ +checkpoints/ *.pth *.onnx diff --git a/.vscode/launch.json b/.vscode/launch.json index bccf130..178c089 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,4 +19,4 @@ "justMyCode": true } ] -} +} \ No newline at end of file diff --git a/src/predict.py b/src/predict.py index 98a48eb..a69f7bc 100755 --- a/src/predict.py +++ b/src/predict.py @@ -38,19 +38,6 @@ def get_args(): return parser.parse_args() -def predict_img(net, img, device): - img = img.unsqueeze(0) - img = img.to(device=device, dtype=torch.float32) - - net.eval() - with torch.inference_mode(): - output = net(img) - # preds = torch.sigmoid(output)[0] - # full_mask = output.squeeze(0).cpu() - - return np.asarray(output.squeeze().cpu()) - - if __name__ == "__main__": args = get_args() @@ -81,8 +68,17 @@ if __name__ == "__main__": img = aug["image"] logging.info(f"Predicting image {args.input}") - mask = predict_img(net=net, img=img, device=device) + img = img.unsqueeze(0).to(device=device, dtype=torch.float32) + + net.eval() + with torch.inference_mode(): + mask = net(img) + mask = torch.sigmoid(mask)[0] + mask = mask.cpu() + mask = mask.squeeze() + mask = mask > 0.5 + mask = np.asarray(mask) logging.info(f"Saving prediction to {args.output}") - mask = Image.fromarray(mask, "L") + mask = Image.fromarray(mask) mask.save(args.output) diff --git a/src/train.py b/src/train.py index 57e1179..c717652 100644 --- a/src/train.py +++ b/src/train.py @@ -26,7 +26,7 @@ def main(): DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", - FEATURES=[64, 128, 256, 512], + FEATURES=[16, 32, 64, 128], N_CHANNELS=3, N_CLASSES=1, AMP=True, @@ -35,8 +35,8 @@ def main(): DEVICE="cuda", WORKERS=8, EPOCHS=5, - BATCH_SIZE=16, - LEARNING_RATE=1e-5, + BATCH_SIZE=64, + LEARNING_RATE=1e-4, IMG_SIZE=512, SPHERES=5, ), @@ -50,7 +50,8 @@ def main(): # 0. Create network net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) - wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad) + wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) + wandb.watch(net, log_freq=100) # transfer network to device net.to(device=device) @@ -80,6 +81,11 @@ def main(): # 2. Create datasets ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) + # ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train) + # ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid) + + # ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k]) + # ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers # 3. Create data loaders train_loader = DataLoader( @@ -99,24 +105,24 @@ def main(): ) # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp - optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9) + optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) criterion = torch.nn.BCEWithLogitsLoss() # save model.pth - wandb.watch(net, log_freq=100) + torch.save(net.state_dict(), "checkpoints/model-0.pth") artifact = wandb.Artifact("pth", type="model") - artifact.add_file("model.pth") + artifact.add_file("checkpoints/model-0.pth") wandb.run.log_artifact(artifact) # save model.onxx dummy_input = torch.randn( 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True ).to(device) - torch.onnx.export(net, dummy_input, "model.onnx") + torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx") artifact = wandb.Artifact("onnx", type="model") - artifact.add_file("model.onnx") + artifact.add_file("checkpoints/model-0.onnx") wandb.run.log_artifact(artifact) # print the config @@ -145,7 +151,7 @@ def main(): # forward with torch.cuda.amp.autocast(enabled=wandb.config.AMP): pred_masks = net(images) - train_loss = criterion(pred_masks, true_masks) + train_loss = criterion(true_masks, pred_masks) # backward optimizer.zero_grad(set_to_none=True) @@ -167,7 +173,7 @@ def main(): # Evaluation round val_score = evaluate(net, val_loader, device) - scheduler.step(val_score) + # scheduler.step(val_score) # log validation metrics wandb.log( @@ -177,18 +183,19 @@ def main(): ) # save weights when epoch end - torch.save(net.state_dict(), "model.pth") + torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth") artifact = wandb.Artifact("pth", type="model") - artifact.add_file("model.pth") + artifact.add_file(f"checkpoints/model-{epoch}.pth") wandb.run.log_artifact(artifact) # export model to onnx format dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) - torch.onnx.export(net, dummy_input, "model.onnx") + torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx") artifact = wandb.Artifact("onnx", type="model") - artifact.add_file("model.onnx") + artifact.add_file(f"checkpoints/model-{epoch}.onnx") wandb.run.log_artifact(artifact) + # stop wandb wandb.run.finish() except KeyboardInterrupt: diff --git a/src/unet/blocks.py b/src/unet/blocks.py index 1f4a854..b5b9267 100644 --- a/src/unet/blocks.py +++ b/src/unet/blocks.py @@ -70,7 +70,10 @@ class OutConv(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1), + nn.Sigmoid(), + ) def forward(self, x): return self.conv(x)