feat(WIP): broken onnx prediction

Former-commit-id: cde7623ec486cf79a710949085aadd92d8a33a3e [formerly db0f1d0b9ea536c741f23a3b683e19a9335bcd35]
Former-commit-id: 7332ccb0f74c58c3a284a4568fb8f80a6d416cf4
Laurent Fainsin 2022-07-05 12:06:12 +02:00
parent d785a5c6be
commit 0fb1d4fb7a
5 changed files with 64 additions and 125 deletions

.vscode/launch.json vendored

@ -14,7 +14,9 @@
"--input",
"images/SM.png",
"--output",
"output.png",
"output_onnx.png",
"--model",
"good.onnx",
],
"justMyCode": true
}

View file

@ -2,13 +2,12 @@ import argparse
import logging
import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from unet import UNet
def get_args():
parser = argparse.ArgumentParser(
@ -38,47 +37,35 @@ def get_args():
return parser.parse_args()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
if __name__ == "__main__":
args = get_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = UNet(n_channels=3, n_classes=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
logging.info("Transfering model to device")
net.to(device=device)
logging.info(f"Loading model {args.model}")
net.load_state_dict(torch.load(args.model, map_location=device))
net = cv2.dnn.readNetFromONNX(args.model)
logging.info("onnx model loaded")
logging.info(f"Loading image {args.input}")
img = Image.open(args.input).convert("RGB")
input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
input_img = input_img.astype(np.float32)
# input_img = cv2.resize(input_img, (512, 512))
logging.info(f"Preprocessing image {args.input}")
tf = A.Compose(
[
A.ToFloat(max_value=255),
ToTensorV2(),
],
logging.info("converting to blob")
input_blob = cv2.dnn.blobFromImage(
image=input_img,
scalefactor=1 / 255,
)
aug = tf(image=np.asarray(img))
img = aug["image"]
logging.info(f"Predicting image {args.input}")
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
net.eval()
with torch.inference_mode():
mask = net(img)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
net.setInput(input_blob)
mask = net.forward()
mask = sigmoid(mask)
mask = mask > 0.5
mask = np.asarray(mask)
mask = mask.astype(np.float32)
logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask)
mask = Image.fromarray(mask, "L")
mask.save(args.output)
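For reference, a minimal sketch of the cv2.dnn inference path this commit switches to, assembled from the hunks above. The swapRB flag and the uint8 conversion before saving are assumptions about why the prediction comes out broken: cv2.imread returns BGR while the UNet was presumably trained on RGB, and PIL's "L" mode expects uint8 data, not the float32 array the diff hands to Image.fromarray.

import cv2
import numpy as np
from PIL import Image

net = cv2.dnn.readNetFromONNX("good.onnx")

img = cv2.imread("images/SM.png", cv2.IMREAD_COLOR).astype(np.float32)
blob = cv2.dnn.blobFromImage(
    image=img,
    scalefactor=1 / 255,
    swapRB=True,  # assumption: undo OpenCV's BGR ordering for an RGB-trained net
)

net.setInput(blob)
logits = net.forward()  # NCHW output, here (1, 1, H, W)

mask = 1 / (1 + np.exp(-logits))  # sigmoid, as in the script above
mask = mask[0, 0] > 0.5

# scale the boolean mask to uint8 before handing it to PIL
Image.fromarray((mask * 255).astype(np.uint8), mode="L").save("output_onnx.png")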

View file

@ -5,12 +5,13 @@ import torch
import yaml
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from torchmetrics import Dice
from tqdm import tqdm
import wandb
from src.utils.dataset import SphereDataset
from unet import UNet
from utils.dice import dice_coeff
from utils.dice import DiceLoss
from utils.paste import RandomPaste
class_labels = {
@ -37,8 +38,8 @@ if __name__ == "__main__":
PIN_MEMORY=True,
BENCHMARK=True,
DEVICE="cuda",
WORKERS=8,
EPOCHS=5,
WORKERS=7,
EPOCHS=1001,
BATCH_SIZE=16,
LEARNING_RATE=1e-4,
WEIGHT_DECAY=1e-8,
@ -92,9 +93,13 @@ if __name__ == "__main__":
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
# 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
# ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
# ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
ds_train = torch.utils.data.Subset(ds_train, [0])
ds_valid = torch.utils.data.Subset(ds_valid, [0])
ds_test = torch.utils.data.Subset(ds_test, [0])
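Pinning all three splits to Subset(ds, [0]) while EPOCHS jumps from 5 to 1001 reads as a deliberate overfit-on-a-single-sample sanity check: if the training loop and the exported ONNX model are wired correctly, the network should memorize that one image almost perfectly.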
# 3. Create data loaders
train_loader = DataLoader(
@ -131,18 +136,19 @@ if __name__ == "__main__":
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()
# save model.onnx
dummy_input = torch.randn(
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device)
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("checkpoints/model-0.onnx")
wandb.run.log_artifact(artifact)
# log gradients and weights four times per epoch
wandb.watch(net, criterion, log_freq=100)
wandb.watch(net, log_freq=100)
# print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
@ -176,6 +182,8 @@ if __name__ == "__main__":
pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks)
# compute loss
# backward
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(train_loss).backward()
@ -185,7 +193,7 @@ if __name__ == "__main__":
# compute metrics
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
accuracy = (true_masks == pred_masks_bin).float().mean()
dice = dice_coeff(pred_masks_bin, true_masks)
dice = dice_loss.coeff(pred_masks, true_masks)
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
# update tqdm progress bar
@ -197,13 +205,13 @@ if __name__ == "__main__":
{
"epoch": epoch - 1 + step / len(train_loader),
"train/accuracy": accuracy,
"train/bce": train_loss,
"train/loss": train_loss,
"train/dice": dice,
"train/mae": mae,
}
)
if step and (step % 250 == 0 or step == len(train_loader)):
if step and (step % 100 == 0 or step == len(train_loader)):
# Evaluation round
net.eval()
accuracy = 0
@ -223,10 +231,10 @@ if __name__ == "__main__":
# compute metrics
val_loss += criterion(masks_pred, masks_true)
dice += dice_loss.coeff(pred_masks, true_masks)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar
pbar2.update(images.shape[0])
@ -267,7 +275,7 @@ if __name__ == "__main__":
"val/predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy,
"val/bce": val_loss,
"val/loss": val_loss,
"val/dice": dice,
"val/mae": mae,
},
@ -276,7 +284,7 @@ if __name__ == "__main__":
# update hyperparameters
net.train()
scheduler.step(dice)
scheduler.step(train_loss)
# export model to onnx format when validation ends
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
@ -304,10 +312,10 @@ if __name__ == "__main__":
# compute metrics
val_loss += criterion(masks_pred, masks_true)
dice += dice_loss.coeff(pred_masks, true_masks)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar
pbar3.update(images.shape[0])
@ -347,7 +355,7 @@ if __name__ == "__main__":
{
"test/predictions": table,
"test/accuracy": accuracy,
"test/bce": val_loss,
"test/loss": val_loss,
"test/dice": dice,
"test/mae": mae,
},

View file

@ -1,80 +1,22 @@
import torch
from torch import Tensor
import torch.nn as nn
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all batches, or for a single mask.
class DiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceLoss, self).__init__()
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
@staticmethod
def coeff(inputs, targets, smooth=1):
# comment out if your model contains a sigmoid or equivalent activation layer
inputs = torch.sigmoid(inputs)
Raises:
ValueError: _description_
# flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
Returns:
float: _description_
"""
assert input.size() == target.size()
intersection = (inputs * targets).sum()
return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all classes.
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
Returns:
float: _description_
"""
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
"""Dice loss (objective to minimize) between 0 and 1.
Args:
input (Tensor): _description_
target (Tensor): _description_
multiclass (bool, optional): _description_. Defaults to False.
Returns:
float: _description_
"""
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)
def forward(self, inputs, targets, smooth=1):
return 1 - self.coeff(inputs, targets, smooth)
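A quick usage sketch of the new DiceLoss; the tensor shapes are illustrative only. Since coeff applies torch.sigmoid itself, it expects raw logits, which is consistent with the training-loop change from dice_coeff(pred_masks_bin, ...) to dice_loss.coeff(pred_masks, ...).

import torch

from utils.dice import DiceLoss  # module path as imported by the training script

dice_loss = DiceLoss()

logits = torch.randn(4, 1, 512, 512)                     # raw network outputs
targets = torch.randint(0, 2, (4, 1, 512, 512)).float()  # binary masks

coeff = dice_loss.coeff(logits, targets)  # dice coefficient, in (0, 1]
loss = dice_loss(logits, targets)         # 1 - coeff, the objective to minimize

One loose end visible in the training diff: both evaluation loops accumulate dice += dice_loss.coeff(pred_masks, true_masks), i.e. the training-batch tensors, while the surrounding lines use masks_pred and masks_true.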

View file

@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
nb,
path_paste_img_dir,
path_paste_mask_dir,
scale_range=(0.1, 0.2),
scale_range=(0.05, 0.25),
always_apply=True,
p=1.0,
):
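The widened scale_range presumably bounds the size of pasted spheres relative to the target image. RandomPaste's internals are not shown in this diff; the snippet below is only a guess at how such a parameter is typically consumed.

import random

# hypothetical consumption of scale_range: pick a paste size as a random
# fraction of the target image's smaller side
def sample_paste_side(img_w: int, img_h: int, scale_range=(0.05, 0.25)) -> int:
    scale = random.uniform(*scale_range)
    return int(min(img_w, img_h) * scale)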