feat(WIP): broken onnx prediction
Former-commit-id: cde7623ec486cf79a710949085aadd92d8a33a3e [formerly db0f1d0b9ea536c741f23a3b683e19a9335bcd35] Former-commit-id: 7332ccb0f74c58c3a284a4568fb8f80a6d416cf4
This commit is contained in:
parent
d785a5c6be
commit
0fb1d4fb7a
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
|
@ -14,7 +14,9 @@
|
||||||
"--input",
|
"--input",
|
||||||
"images/SM.png",
|
"images/SM.png",
|
||||||
"--output",
|
"--output",
|
||||||
"output.png",
|
"output_onnx.png",
|
||||||
|
"--model",
|
||||||
|
"good.onnx",
|
||||||
],
|
],
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,13 +2,12 @@ import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from unet import UNet
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -38,47 +37,35 @@ def get_args():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid(x):
|
||||||
|
return 1 / (1 + np.exp(-x))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
net = UNet(n_channels=3, n_classes=1)
|
net = cv2.dnn.readNetFromONNX(args.model)
|
||||||
|
logging.info("onnx model loaded")
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
logging.info(f"Using device {device}")
|
|
||||||
|
|
||||||
logging.info("Transfering model to device")
|
|
||||||
net.to(device=device)
|
|
||||||
|
|
||||||
logging.info(f"Loading model {args.model}")
|
|
||||||
net.load_state_dict(torch.load(args.model, map_location=device))
|
|
||||||
|
|
||||||
logging.info(f"Loading image {args.input}")
|
logging.info(f"Loading image {args.input}")
|
||||||
img = Image.open(args.input).convert("RGB")
|
input_img = cv2.imread(args.input, cv2.IMREAD_COLOR)
|
||||||
|
input_img = input_img.astype(np.float32)
|
||||||
|
# input_img = cv2.resize(input_img, (512, 512))
|
||||||
|
|
||||||
logging.info(f"Preprocessing image {args.input}")
|
logging.info("converting to blob")
|
||||||
tf = A.Compose(
|
input_blob = cv2.dnn.blobFromImage(
|
||||||
[
|
image=input_img,
|
||||||
A.ToFloat(max_value=255),
|
scalefactor=1 / 255,
|
||||||
ToTensorV2(),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
aug = tf(image=np.asarray(img))
|
|
||||||
img = aug["image"]
|
|
||||||
|
|
||||||
logging.info(f"Predicting image {args.input}")
|
net.setInput(input_blob)
|
||||||
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
|
mask = net.forward()
|
||||||
|
mask = sigmoid(mask)
|
||||||
net.eval()
|
|
||||||
with torch.inference_mode():
|
|
||||||
mask = net(img)
|
|
||||||
mask = torch.sigmoid(mask)[0]
|
|
||||||
mask = mask.cpu()
|
|
||||||
mask = mask.squeeze()
|
|
||||||
mask = mask > 0.5
|
mask = mask > 0.5
|
||||||
mask = np.asarray(mask)
|
mask = mask.astype(np.float32)
|
||||||
|
|
||||||
logging.info(f"Saving prediction to {args.output}")
|
logging.info(f"Saving prediction to {args.output}")
|
||||||
mask = Image.fromarray(mask)
|
mask = Image.fromarray(mask, "L")
|
||||||
mask.save(args.output)
|
mask.save(args.output)
|
||||||
|
|
40
src/train.py
40
src/train.py
|
@ -5,12 +5,13 @@ import torch
|
||||||
import yaml
|
import yaml
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from torchmetrics import Dice
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from src.utils.dataset import SphereDataset
|
from src.utils.dataset import SphereDataset
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils.dice import dice_coeff
|
from utils.dice import DiceLoss
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
class_labels = {
|
class_labels = {
|
||||||
|
@ -37,8 +38,8 @@ if __name__ == "__main__":
|
||||||
PIN_MEMORY=True,
|
PIN_MEMORY=True,
|
||||||
BENCHMARK=True,
|
BENCHMARK=True,
|
||||||
DEVICE="cuda",
|
DEVICE="cuda",
|
||||||
WORKERS=8,
|
WORKERS=7,
|
||||||
EPOCHS=5,
|
EPOCHS=1001,
|
||||||
BATCH_SIZE=16,
|
BATCH_SIZE=16,
|
||||||
LEARNING_RATE=1e-4,
|
LEARNING_RATE=1e-4,
|
||||||
WEIGHT_DECAY=1e-8,
|
WEIGHT_DECAY=1e-8,
|
||||||
|
@ -92,9 +93,13 @@ if __name__ == "__main__":
|
||||||
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
|
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
|
||||||
|
|
||||||
# 2.5. Create subset, if uncommented
|
# 2.5. Create subset, if uncommented
|
||||||
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||||
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
|
# ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
|
||||||
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
|
# ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
|
||||||
|
|
||||||
|
ds_train = torch.utils.data.Subset(ds_train, [0])
|
||||||
|
ds_valid = torch.utils.data.Subset(ds_valid, [0])
|
||||||
|
ds_test = torch.utils.data.Subset(ds_test, [0])
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
|
@ -131,18 +136,19 @@ if __name__ == "__main__":
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||||
criterion = torch.nn.BCEWithLogitsLoss()
|
criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
|
dice_loss = DiceLoss()
|
||||||
|
|
||||||
# save model.onxx
|
# save model.onxx
|
||||||
dummy_input = torch.randn(
|
dummy_input = torch.randn(
|
||||||
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
||||||
).to(device)
|
).to(device)
|
||||||
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
|
torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
artifact.add_file("checkpoints/model-0.onnx")
|
artifact.add_file("checkpoints/model-0.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
# log gradients and weights four time per epoch
|
# log gradients and weights four time per epoch
|
||||||
wandb.watch(net, criterion, log_freq=100)
|
wandb.watch(net, log_freq=100)
|
||||||
|
|
||||||
# print the config
|
# print the config
|
||||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||||
|
@ -176,6 +182,8 @@ if __name__ == "__main__":
|
||||||
pred_masks = net(images)
|
pred_masks = net(images)
|
||||||
train_loss = criterion(pred_masks, true_masks)
|
train_loss = criterion(pred_masks, true_masks)
|
||||||
|
|
||||||
|
# compute loss
|
||||||
|
|
||||||
# backward
|
# backward
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
grad_scaler.scale(train_loss).backward()
|
grad_scaler.scale(train_loss).backward()
|
||||||
|
@ -185,7 +193,7 @@ if __name__ == "__main__":
|
||||||
# compute metrics
|
# compute metrics
|
||||||
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
|
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
|
||||||
accuracy = (true_masks == pred_masks_bin).float().mean()
|
accuracy = (true_masks == pred_masks_bin).float().mean()
|
||||||
dice = dice_coeff(pred_masks_bin, true_masks)
|
dice = dice_loss.coeff(pred_masks, true_masks)
|
||||||
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
||||||
|
|
||||||
# update tqdm progress bar
|
# update tqdm progress bar
|
||||||
|
@ -197,13 +205,13 @@ if __name__ == "__main__":
|
||||||
{
|
{
|
||||||
"epoch": epoch - 1 + step / len(train_loader),
|
"epoch": epoch - 1 + step / len(train_loader),
|
||||||
"train/accuracy": accuracy,
|
"train/accuracy": accuracy,
|
||||||
"train/bce": train_loss,
|
"train/loss": train_loss,
|
||||||
"train/dice": dice,
|
"train/dice": dice,
|
||||||
"train/mae": mae,
|
"train/mae": mae,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if step and (step % 250 == 0 or step == len(train_loader)):
|
if step and (step % 100 == 0 or step == len(train_loader)):
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
net.eval()
|
net.eval()
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
|
@ -223,10 +231,10 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
val_loss += criterion(masks_pred, masks_true)
|
val_loss += criterion(masks_pred, masks_true)
|
||||||
|
dice += dice_loss.coeff(pred_masks, true_masks)
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
accuracy += (masks_true == masks_pred_bin).float().mean()
|
accuracy += (masks_true == masks_pred_bin).float().mean()
|
||||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
pbar2.update(images.shape[0])
|
pbar2.update(images.shape[0])
|
||||||
|
@ -267,7 +275,7 @@ if __name__ == "__main__":
|
||||||
"val/predictions": table,
|
"val/predictions": table,
|
||||||
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
||||||
"val/accuracy": accuracy,
|
"val/accuracy": accuracy,
|
||||||
"val/bce": val_loss,
|
"val/loss": val_loss,
|
||||||
"val/dice": dice,
|
"val/dice": dice,
|
||||||
"val/mae": mae,
|
"val/mae": mae,
|
||||||
},
|
},
|
||||||
|
@ -276,7 +284,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# update hyperparameters
|
# update hyperparameters
|
||||||
net.train()
|
net.train()
|
||||||
scheduler.step(dice)
|
scheduler.step(train_loss)
|
||||||
|
|
||||||
# export model to onnx format when validation ends
|
# export model to onnx format when validation ends
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||||
|
@ -304,10 +312,10 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
val_loss += criterion(masks_pred, masks_true)
|
val_loss += criterion(masks_pred, masks_true)
|
||||||
|
dice += dice_loss.coeff(pred_masks, true_masks)
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
accuracy += (masks_true == masks_pred_bin).float().mean()
|
accuracy += (masks_true == masks_pred_bin).float().mean()
|
||||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
pbar3.update(images.shape[0])
|
pbar3.update(images.shape[0])
|
||||||
|
@ -347,7 +355,7 @@ if __name__ == "__main__":
|
||||||
{
|
{
|
||||||
"test/predictions": table,
|
"test/predictions": table,
|
||||||
"test/accuracy": accuracy,
|
"test/accuracy": accuracy,
|
||||||
"test/bce": val_loss,
|
"test/loss": val_loss,
|
||||||
"test/dice": dice,
|
"test/dice": dice,
|
||||||
"test/mae": mae,
|
"test/mae": mae,
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,80 +1,22 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
class DiceLoss(nn.Module):
|
||||||
"""Average of Dice coefficient for all batches, or for a single mask.
|
def __init__(self, weight=None, size_average=True):
|
||||||
|
super(DiceLoss, self).__init__()
|
||||||
|
|
||||||
Args:
|
@staticmethod
|
||||||
input (Tensor): _description_
|
def coeff(inputs, targets, smooth=1):
|
||||||
target (Tensor): _description_
|
# comment out if your model contains a sigmoid or equivalent activation layer
|
||||||
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
inputs = torch.sigmoid(inputs)
|
||||||
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
|
||||||
|
|
||||||
Raises:
|
# flatten label and prediction tensors
|
||||||
ValueError: _description_
|
inputs = inputs.view(-1)
|
||||||
|
targets = targets.view(-1)
|
||||||
|
|
||||||
Returns:
|
intersection = (inputs * targets).sum()
|
||||||
float: _description_
|
return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
|
||||||
"""
|
|
||||||
assert input.size() == target.size()
|
|
||||||
|
|
||||||
if input.dim() == 2 and reduce_batch_first:
|
def forward(self, inputs, targets, smooth=1):
|
||||||
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
|
return 1 - self.coeff(inputs, targets, smooth)
|
||||||
|
|
||||||
if input.dim() == 2 or reduce_batch_first:
|
|
||||||
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
|
||||||
sets_sum = torch.sum(input) + torch.sum(target)
|
|
||||||
|
|
||||||
if sets_sum.item() == 0:
|
|
||||||
sets_sum = 2 * inter
|
|
||||||
|
|
||||||
return (2 * inter + epsilon) / (sets_sum + epsilon)
|
|
||||||
else:
|
|
||||||
# compute and average metric for each batch element
|
|
||||||
dice = 0
|
|
||||||
|
|
||||||
for i in range(input.shape[0]):
|
|
||||||
dice += dice_coeff(input[i, ...], target[i, ...])
|
|
||||||
|
|
||||||
return dice / input.shape[0]
|
|
||||||
|
|
||||||
|
|
||||||
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
|
||||||
"""Average of Dice coefficient for all classes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input (Tensor): _description_
|
|
||||||
target (Tensor): _description_
|
|
||||||
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
|
||||||
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: _description_
|
|
||||||
"""
|
|
||||||
assert input.size() == target.size()
|
|
||||||
|
|
||||||
dice = 0
|
|
||||||
|
|
||||||
for channel in range(input.shape[1]):
|
|
||||||
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
|
|
||||||
|
|
||||||
return dice / input.shape[1]
|
|
||||||
|
|
||||||
|
|
||||||
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
|
|
||||||
"""Dice loss (objective to minimize) between 0 and 1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input (Tensor): _description_
|
|
||||||
target (Tensor): _description_
|
|
||||||
multiclass (bool, optional): _description_. Defaults to False.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: _description_
|
|
||||||
"""
|
|
||||||
assert input.size() == target.size()
|
|
||||||
|
|
||||||
fn = multiclass_dice_coeff if multiclass else dice_coeff
|
|
||||||
|
|
||||||
return 1 - fn(input, target, reduce_batch_first=True)
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform):
|
||||||
nb,
|
nb,
|
||||||
path_paste_img_dir,
|
path_paste_img_dir,
|
||||||
path_paste_mask_dir,
|
path_paste_mask_dir,
|
||||||
scale_range=(0.1, 0.2),
|
scale_range=(0.05, 0.25),
|
||||||
always_apply=True,
|
always_apply=True,
|
||||||
p=1.0,
|
p=1.0,
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in a new issue