feat(WIP): broken onnx prediction

Former-commit-id: cde7623ec486cf79a710949085aadd92d8a33a3e [formerly db0f1d0b9ea536c741f23a3b683e19a9335bcd35]
Former-commit-id: 7332ccb0f74c58c3a284a4568fb8f80a6d416cf4
This commit is contained in:
Laurent Fainsin 2022-07-05 12:06:12 +02:00
parent d785a5c6be
commit 0fb1d4fb7a
5 changed files with 64 additions and 125 deletions

6
.vscode/launch.json vendored
View file

@ -14,9 +14,11 @@
"--input", "--input",
"images/SM.png", "images/SM.png",
"--output", "--output",
"output.png", "output_onnx.png",
"--model",
"good.onnx",
], ],
"justMyCode": true "justMyCode": true
} }
] ]
} }

View file

@ -2,13 +2,12 @@ import argparse
import logging import logging
import albumentations as A import albumentations as A
import cv2
import numpy as np import numpy as np
import torch import torch
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from PIL import Image from PIL import Image
from unet import UNet
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -38,47 +37,35 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = UNet(n_channels=3, n_classes=1) net = cv2.dnn.readNetFromONNX(args.model)
logging.info("onnx model loaded")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
logging.info("Transfering model to device")
net.to(device=device)
logging.info(f"Loading model {args.model}")
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info(f"Loading image {args.input}") logging.info(f"Loading image {args.input}")
img = Image.open(args.input).convert("RGB") input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
input_img = input_img.astype(np.float32)
# input_img = cv2.resize(input_img, (512, 512))
logging.info(f"Preprocessing image {args.input}") logging.info("converting to blob")
tf = A.Compose( input_blob = cv2.dnn.blobFromImage(
[ image=input_img,
A.ToFloat(max_value=255), scalefactor=1 / 255,
ToTensorV2(),
],
) )
aug = tf(image=np.asarray(img))
img = aug["image"]
logging.info(f"Predicting image {args.input}") net.setInput(input_blob)
img = img.unsqueeze(0).to(device=device, dtype=torch.float32) mask = net.forward()
mask = sigmoid(mask)
net.eval() mask = mask > 0.5
with torch.inference_mode(): mask = mask.astype(np.float32)
mask = net(img)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
mask = mask > 0.5
mask = np.asarray(mask)
logging.info(f"Saving prediction to {args.output}") logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask) mask = Image.fromarray(mask, "L")
mask.save(args.output) mask.save(args.output)

View file

@ -5,12 +5,13 @@ import torch
import yaml import yaml
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchmetrics import Dice
from tqdm import tqdm from tqdm import tqdm
import wandb import wandb
from src.utils.dataset import SphereDataset from src.utils.dataset import SphereDataset
from unet import UNet from unet import UNet
from utils.dice import dice_coeff from utils.dice import DiceLoss
from utils.paste import RandomPaste from utils.paste import RandomPaste
class_labels = { class_labels = {
@ -37,8 +38,8 @@ if __name__ == "__main__":
PIN_MEMORY=True, PIN_MEMORY=True,
BENCHMARK=True, BENCHMARK=True,
DEVICE="cuda", DEVICE="cuda",
WORKERS=8, WORKERS=7,
EPOCHS=5, EPOCHS=1001,
BATCH_SIZE=16, BATCH_SIZE=16,
LEARNING_RATE=1e-4, LEARNING_RATE=1e-4,
WEIGHT_DECAY=1e-8, WEIGHT_DECAY=1e-8,
@ -92,9 +93,13 @@ if __name__ == "__main__":
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
# 2.5. Create subset, if uncommented # 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000))) # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
ds_train = torch.utils.data.Subset(ds_train, [0])
ds_valid = torch.utils.data.Subset(ds_valid, [0])
ds_test = torch.utils.data.Subset(ds_test, [0])
# 3. Create data loaders # 3. Create data loaders
train_loader = DataLoader( train_loader = DataLoader(
@ -131,18 +136,19 @@ if __name__ == "__main__":
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()
# save model.onxx # save model.onxx
dummy_input = torch.randn( dummy_input = torch.randn(
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device) ).to(device)
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx") torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model") artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("checkpoints/model-0.onnx") artifact.add_file("checkpoints/model-0.onnx")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# log gradients and weights four time per epoch # log gradients and weights four time per epoch
wandb.watch(net, criterion, log_freq=100) wandb.watch(net, log_freq=100)
# print the config # print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
@ -176,6 +182,8 @@ if __name__ == "__main__":
pred_masks = net(images) pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks) train_loss = criterion(pred_masks, true_masks)
# compute loss
# backward # backward
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(train_loss).backward() grad_scaler.scale(train_loss).backward()
@ -185,7 +193,7 @@ if __name__ == "__main__":
# compute metrics # compute metrics
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
accuracy = (true_masks == pred_masks_bin).float().mean() accuracy = (true_masks == pred_masks_bin).float().mean()
dice = dice_coeff(pred_masks_bin, true_masks) dice = dice_loss.coeff(pred_masks, true_masks)
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
# update tqdm progress bar # update tqdm progress bar
@ -197,13 +205,13 @@ if __name__ == "__main__":
{ {
"epoch": epoch - 1 + step / len(train_loader), "epoch": epoch - 1 + step / len(train_loader),
"train/accuracy": accuracy, "train/accuracy": accuracy,
"train/bce": train_loss, "train/loss": train_loss,
"train/dice": dice, "train/dice": dice,
"train/mae": mae, "train/mae": mae,
} }
) )
if step and (step % 250 == 0 or step == len(train_loader)): if step and (step % 100 == 0 or step == len(train_loader)):
# Evaluation round # Evaluation round
net.eval() net.eval()
accuracy = 0 accuracy = 0
@ -223,10 +231,10 @@ if __name__ == "__main__":
# compute metrics # compute metrics
val_loss += criterion(masks_pred, masks_true) val_loss += criterion(masks_pred, masks_true)
dice += dice_loss.coeff(pred_masks, true_masks)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean() accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar # update progress bar
pbar2.update(images.shape[0]) pbar2.update(images.shape[0])
@ -267,7 +275,7 @@ if __name__ == "__main__":
"val/predictions": table, "val/predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy, "val/accuracy": accuracy,
"val/bce": val_loss, "val/loss": val_loss,
"val/dice": dice, "val/dice": dice,
"val/mae": mae, "val/mae": mae,
}, },
@ -276,7 +284,7 @@ if __name__ == "__main__":
# update hyperparameters # update hyperparameters
net.train() net.train()
scheduler.step(dice) scheduler.step(train_loss)
# export model to onnx format when validation ends # export model to onnx format when validation ends
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
@ -304,10 +312,10 @@ if __name__ == "__main__":
# compute metrics # compute metrics
val_loss += criterion(masks_pred, masks_true) val_loss += criterion(masks_pred, masks_true)
dice += dice_loss.coeff(pred_masks, true_masks)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean() accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar # update progress bar
pbar3.update(images.shape[0]) pbar3.update(images.shape[0])
@ -347,7 +355,7 @@ if __name__ == "__main__":
{ {
"test/predictions": table, "test/predictions": table,
"test/accuracy": accuracy, "test/accuracy": accuracy,
"test/bce": val_loss, "test/loss": val_loss,
"test/dice": dice, "test/dice": dice,
"test/mae": mae, "test/mae": mae,
}, },

View file

@ -1,80 +1,22 @@
import torch import torch
from torch import Tensor import torch.nn as nn
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: class DiceLoss(nn.Module):
"""Average of Dice coefficient for all batches, or for a single mask. def __init__(self, weight=None, size_average=True):
super(DiceLoss, self).__init__()
Args: @staticmethod
input (Tensor): _description_ def coeff(inputs, targets, smooth=1):
target (Tensor): _description_ # comment out if your model contains a sigmoid or equivalent activation layer
reduce_batch_first (bool, optional): _description_. Defaults to False. inputs = torch.sigmoid(inputs)
epsilon (_type_, optional): _description_. Defaults to 1e-6.
Raises: # flatten label and prediction tensors
ValueError: _description_ inputs = inputs.view(-1)
targets = targets.view(-1)
Returns: intersection = (inputs * targets).sum()
float: _description_ return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
"""
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first: def forward(self, inputs, targets, smooth=1):
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})") return 1 - self.coeff(inputs, targets, smooth)
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all classes.
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
Returns:
float: _description_
"""
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
"""Dice loss (objective to minimize) between 0 and 1.
Args:
input (Tensor): _description_
target (Tensor): _description_
multiclass (bool, optional): _description_. Defaults to False.
Returns:
float: _description_
"""
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)

View file

@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
nb, nb,
path_paste_img_dir, path_paste_img_dir,
path_paste_mask_dir, path_paste_mask_dir,
scale_range=(0.1, 0.2), scale_range=(0.05, 0.25),
always_apply=True, always_apply=True,
p=1.0, p=1.0,
): ):