diff --git a/src/predict.py b/src/predict.py index 587e11e..98a48eb 100755 --- a/src/predict.py +++ b/src/predict.py @@ -34,33 +34,28 @@ def get_args(): metavar="OUTPUT", help="Filenames of output images", ) - parser.add_argument( - "--threshold", - "-t", - type=float, - default=0.5, - help="Minimum probability value to consider a mask pixel white", - ) return parser.parse_args() -def predict_img(net, img, device, threshold): +def predict_img(net, img, device): img = img.unsqueeze(0) img = img.to(device=device, dtype=torch.float32) net.eval() with torch.inference_mode(): output = net(img) - preds = torch.sigmoid(output)[0] - full_mask = preds.cpu().squeeze() + # preds = torch.sigmoid(output)[0] + # full_mask = output.squeeze(0).cpu() - return np.asarray(full_mask > threshold) + return np.asarray(output.squeeze().cpu()) if __name__ == "__main__": args = get_args() + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + net = UNet(n_channels=3, n_classes=1) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -86,8 +81,8 @@ if __name__ == "__main__": img = aug["image"] logging.info(f"Predicting image {args.input}") - mask = predict_img(net=net, img=img, threshold=args.threshold, device=device) + mask = predict_img(net=net, img=img, device=device) logging.info(f"Saving prediction to {args.output}") - mask = Image.fromarray(mask) - mask.write(args.output) + mask = Image.fromarray(mask, "L") + mask.save(args.output) diff --git a/src/train.py b/src/train.py index f3bda57..6e331b8 100644 --- a/src/train.py +++ b/src/train.py @@ -18,40 +18,39 @@ def main(): # setup logging logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - # enable cuda, if possible - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # setup wandb wandb.init( project="U-Net", config=dict( - n_channels=3, - n_classes=1, - epochs=5, - batch_size=70, - learning_rate=1e-5, - amp=True, - num_workers=8, - pin_memory=True, - features=[16, 32, 64, 128], - benchmark=False, - device=device.type, DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017", DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", + FEATURES=[64, 128, 256, 512], + N_CHANNELS=3, + N_CLASSES=1, + AMP=True, + PIN_MEMORY=True, + BENCHMARK=False, + DEVICE="cuda", + WORKERS=8, + EPOCHS=5, + BATCH_SIZE=16, + LEARNING_RATE=1e-5, + IMG_SIZE=512, + SPHERES=5, ), ) + # create device + device = torch.device(wandb.config.device) + # enable cudnn benchmarking - torch.backends.cudnn.benchmark = wandb.config.benchmark + torch.backends.cudnn.benchmark = wandb.config.BENCHMARK # 0. Create network - net = UNet(n_channels=3, n_classes=wandb.config.n_classes, features=wandb.config.features) - wandb.config.params = sum(p.numel() for p in net.parameters() if p.requires_grad) - - # save initial model.pth - torch.save(net.state_dict(), "model.pth") + net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) + wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad) # transfer network to device net.to(device=device) @@ -59,10 +58,10 @@ def main(): # 1. Create transforms tf_train = A.Compose( [ - A.Resize(512, 512), + A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), A.Flip(), A.ColorJitter(), - RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), A.GaussianBlur(), A.ISONoise(), A.ToFloat(max_value=255), @@ -71,8 +70,8 @@ def main(): ) tf_valid = A.Compose( [ - A.Resize(512, 512), - RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), A.ToFloat(max_value=255), ToTensorV2(), ], @@ -83,16 +82,26 @@ def main(): ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) # 3. Create data loaders - loader_args = dict( - batch_size=wandb.config.batch_size, num_workers=wandb.config.num_workers, pin_memory=wandb.config.pin_memory + train_loader = DataLoader( + ds_train, + shuffle=True, + batch_size=wandb.config.BATCH_SIZE, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) + val_loader = DataLoader( + ds_valid, + shuffle=False, + drop_last=True, + batch_size=wandb.config.BATCH_SIZE, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, ) - train_loader = DataLoader(ds_train, shuffle=True, **loader_args) - val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) - # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP - optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.learning_rate, weight_decay=1e-8, momentum=0.9) + # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp + optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) - grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.amp) + grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) criterion = torch.nn.BCEWithLogitsLoss() # save model.pth @@ -100,27 +109,31 @@ def main(): artifact = wandb.Artifact("pth", type="model") artifact.add_file("model.pth") wandb.run.log_artifact(artifact) - logging.info("model.pth saved") # save model.onxx - dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) + dummy_input = torch.randn( + 1, wandb.config.n_channels, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True + ).to(device) torch.onnx.export(net, dummy_input, "model.onnx") artifact = wandb.Artifact("onnx", type="model") artifact.add_file("model.onnx") wandb.run.log_artifact(artifact) - logging.info("model.onnx saved") # print the config - logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") + logging.info( + f"""wandb config: + {yaml.dump(wandb.config.as_dict())} + """ + ) try: - for epoch in range(1, wandb.config.epochs + 1): - with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.epochs}", unit="img") as pbar: + for epoch in range(1, wandb.config.EPOCHS + 1): + with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar: # Training round for step, (images, true_masks) in enumerate(train_loader): - assert images.shape[1] == net.n_channels, ( - f"Network has been defined with {net.n_channels} input channels, " + assert images.shape[1] == net.N_CHANNELS, ( + f"Network has been defined with {net.N_CHANNELS} input channels, " f"but loaded images have {images.shape[1]} channels. Please check that " "the images are loaded correctly." ) @@ -130,7 +143,7 @@ def main(): true_masks = true_masks.unsqueeze(1).to(device=device) # forward - with torch.cuda.amp.autocast(enabled=wandb.config.amp): + with torch.cuda.amp.autocast(enabled=wandb.config.AMP): pred_masks = net(images) train_loss = criterion(pred_masks, true_masks) @@ -163,36 +176,25 @@ def main(): } ) - logging.info( - f"""Validation ended: - Train Loss: {train_loss} - Valid Score: {val_score} - """ - ) - # save weights when epoch end torch.save(net.state_dict(), "model.pth") artifact = wandb.Artifact("pth", type="model") artifact.add_file("model.pth") wandb.run.log_artifact(artifact) - logging.info("model.pth saved") # export model to onnx format dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) torch.onnx.export(net, dummy_input, "model.onnx") - artifact = wandb.Artifact("pnnx", type="model") + artifact = wandb.Artifact("onnx", type="model") artifact.add_file("model.onnx") wandb.run.log_artifact(artifact) - logging.info("model.onnx saved") wandb.run.finish() except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") - logging.info("Saved interrupt") raise if __name__ == "__main__": - main() - # TODO: fix toutes les metrics, loss, accuracy, dice... + main() # TODO: fix toutes les metrics, loss, accuracy, dice...