From 0fb1d4fb7a52fd95d5f68d1eccf5f8324f4fe08c Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 12:06:12 +0200 Subject: [PATCH] feat(WIP): broken onnx prediction Former-commit-id: cde7623ec486cf79a710949085aadd92d8a33a3e [formerly db0f1d0b9ea536c741f23a3b683e19a9335bcd35] Former-commit-id: 7332ccb0f74c58c3a284a4568fb8f80a6d416cf4 --- .vscode/launch.json | 6 ++-- src/predict.py | 53 +++++++++++---------------- src/train.py | 40 ++++++++++++--------- src/utils/dice.py | 88 ++++++++------------------------------------- src/utils/paste.py | 2 +- 5 files changed, 64 insertions(+), 125 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 178c089..4992c79 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -14,9 +14,11 @@ "--input", "images/SM.png", "--output", - "output.png", + "output_onnx.png", + "--model", + "good.onnx", ], "justMyCode": true } ] -} \ No newline at end of file +} diff --git a/src/predict.py b/src/predict.py index a69f7bc..0df1bab 100755 --- a/src/predict.py +++ b/src/predict.py @@ -2,13 +2,12 @@ import argparse import logging import albumentations as A +import cv2 import numpy as np import torch from albumentations.pytorch import ToTensorV2 from PIL import Image -from unet import UNet - def get_args(): parser = argparse.ArgumentParser( @@ -38,47 +37,35 @@ def get_args(): return parser.parse_args() +def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + if __name__ == "__main__": args = get_args() logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - net = UNet(n_channels=3, n_classes=1) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logging.info(f"Using device {device}") - - logging.info("Transfering model to device") - net.to(device=device) - - logging.info(f"Loading model {args.model}") - net.load_state_dict(torch.load(args.model, map_location=device)) + net = cv2.dnn.readNetFromONNX(args.model) + logging.info("onnx model loaded") logging.info(f"Loading image {args.input}") - img = Image.open(args.input).convert("RGB") + input_img = cv2.imread(args.input, cv2.IMREAD_COLOR) + input_img = input_img.astype(np.float32) + # input_img = cv2.resize(input_img, (512, 512)) - logging.info(f"Preprocessing image {args.input}") - tf = A.Compose( - [ - A.ToFloat(max_value=255), - ToTensorV2(), - ], + logging.info("converting to blob") + input_blob = cv2.dnn.blobFromImage( + image=input_img, + scalefactor=1 / 255, ) - aug = tf(image=np.asarray(img)) - img = aug["image"] - logging.info(f"Predicting image {args.input}") - img = img.unsqueeze(0).to(device=device, dtype=torch.float32) - - net.eval() - with torch.inference_mode(): - mask = net(img) - mask = torch.sigmoid(mask)[0] - mask = mask.cpu() - mask = mask.squeeze() - mask = mask > 0.5 - mask = np.asarray(mask) + net.setInput(input_blob) + mask = net.forward() + mask = sigmoid(mask) + mask = mask > 0.5 + mask = mask.astype(np.float32) logging.info(f"Saving prediction to {args.output}") - mask = Image.fromarray(mask) + mask = Image.fromarray(mask, "L") mask.save(args.output) diff --git a/src/train.py b/src/train.py index 9ddac8e..2963c3b 100644 --- a/src/train.py +++ b/src/train.py @@ -5,12 +5,13 @@ import torch import yaml from albumentations.pytorch import ToTensorV2 from torch.utils.data import DataLoader +from torchmetrics import Dice from tqdm import tqdm import wandb from src.utils.dataset import SphereDataset from unet import UNet -from utils.dice import dice_coeff +from utils.dice import DiceLoss from utils.paste import RandomPaste class_labels = { @@ -37,8 +38,8 @@ if __name__ == "__main__": PIN_MEMORY=True, BENCHMARK=True, DEVICE="cuda", - WORKERS=8, - EPOCHS=5, + WORKERS=7, + EPOCHS=1001, BATCH_SIZE=16, LEARNING_RATE=1e-4, WEIGHT_DECAY=1e-8, @@ -92,9 +93,13 @@ if __name__ == "__main__": ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) # 2.5. Create subset, if uncommented - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) - ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000))) - ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) + # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) + # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) + # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) + + ds_train = torch.utils.data.Subset(ds_train, [0]) + ds_valid = torch.utils.data.Subset(ds_valid, [0]) + ds_test = torch.utils.data.Subset(ds_test, [0]) # 3. Create data loaders train_loader = DataLoader( @@ -131,18 +136,19 @@ if __name__ == "__main__": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) criterion = torch.nn.BCEWithLogitsLoss() + dice_loss = DiceLoss() # save model.onxx dummy_input = torch.randn( 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True ).to(device) - torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx") + torch.onnx.export(net, dummy_input, "checkpoints/model.onnx") artifact = wandb.Artifact("onnx", type="model") artifact.add_file("checkpoints/model-0.onnx") wandb.run.log_artifact(artifact) # log gradients and weights four time per epoch - wandb.watch(net, criterion, log_freq=100) + wandb.watch(net, log_freq=100) # print the config logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") @@ -176,6 +182,8 @@ if __name__ == "__main__": pred_masks = net(images) train_loss = criterion(pred_masks, true_masks) + # compute loss + # backward optimizer.zero_grad(set_to_none=True) grad_scaler.scale(train_loss).backward() @@ -185,7 +193,7 @@ if __name__ == "__main__": # compute metrics pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() accuracy = (true_masks == pred_masks_bin).float().mean() - dice = dice_coeff(pred_masks_bin, true_masks) + dice = dice_loss.coeff(pred_masks, true_masks) mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) # update tqdm progress bar @@ -197,13 +205,13 @@ if __name__ == "__main__": { "epoch": epoch - 1 + step / len(train_loader), "train/accuracy": accuracy, - "train/bce": train_loss, + "train/loss": train_loss, "train/dice": dice, "train/mae": mae, } ) - if step and (step % 250 == 0 or step == len(train_loader)): + if step and (step % 100 == 0 or step == len(train_loader)): # Evaluation round net.eval() accuracy = 0 @@ -223,10 +231,10 @@ if __name__ == "__main__": # compute metrics val_loss += criterion(masks_pred, masks_true) + dice += dice_loss.coeff(pred_masks, true_masks) masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) # update progress bar pbar2.update(images.shape[0]) @@ -267,7 +275,7 @@ if __name__ == "__main__": "val/predictions": table, "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], "val/accuracy": accuracy, - "val/bce": val_loss, + "val/loss": val_loss, "val/dice": dice, "val/mae": mae, }, @@ -276,7 +284,7 @@ if __name__ == "__main__": # update hyperparameters net.train() - scheduler.step(dice) + scheduler.step(train_loss) # export model to onnx format when validation ends dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) @@ -304,10 +312,10 @@ if __name__ == "__main__": # compute metrics val_loss += criterion(masks_pred, masks_true) + dice += dice_loss.coeff(pred_masks, true_masks) masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) # update progress bar pbar3.update(images.shape[0]) @@ -347,7 +355,7 @@ if __name__ == "__main__": { "test/predictions": table, "test/accuracy": accuracy, - "test/bce": val_loss, + "test/loss": val_loss, "test/dice": dice, "test/mae": mae, }, diff --git a/src/utils/dice.py b/src/utils/dice.py index a29f794..c4b201c 100644 --- a/src/utils/dice.py +++ b/src/utils/dice.py @@ -1,80 +1,22 @@ import torch -from torch import Tensor +import torch.nn as nn -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all batches, or for a single mask. +class DiceLoss(nn.Module): + def __init__(self, weight=None, size_average=True): + super(DiceLoss, self).__init__() - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. + @staticmethod + def coeff(inputs, targets, smooth=1): + # comment out if your model contains a sigmoid or equivalent activation layer + inputs = torch.sigmoid(inputs) - Raises: - ValueError: _description_ + # flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) - Returns: - float: _description_ - """ - assert input.size() == target.size() + intersection = (inputs * targets).sum() + return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) - if input.dim() == 2 and reduce_batch_first: - raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})") - - if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.reshape(-1), target.reshape(-1)) - sets_sum = torch.sum(input) + torch.sum(target) - - if sets_sum.item() == 0: - sets_sum = 2 * inter - - return (2 * inter + epsilon) / (sets_sum + epsilon) - else: - # compute and average metric for each batch element - dice = 0 - - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) - - return dice / input.shape[0] - - -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all classes. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - dice = 0 - - for channel in range(input.shape[1]): - dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) - - return dice / input.shape[1] - - -def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float: - """Dice loss (objective to minimize) between 0 and 1. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - multiclass (bool, optional): _description_. Defaults to False. - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - fn = multiclass_dice_coeff if multiclass else dice_coeff - - return 1 - fn(input, target, reduce_batch_first=True) + def forward(self, inputs, targets, smooth=1): + return 1 - self.coeff(inputs, targets, smooth) diff --git a/src/utils/paste.py b/src/utils/paste.py index 486a8ec..90ef0a8 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform): nb, path_paste_img_dir, path_paste_mask_dir, - scale_range=(0.1, 0.2), + scale_range=(0.05, 0.25), always_apply=True, p=1.0, ):