feat: binarize the masks + add lots of new metrics (still to fix)

Former-commit-id: c840d14f722503d241f6bb6d899630ad6345aca0 [formerly e435a21234620add4f0e4e269a4141e5c1508cd9]
Former-commit-id: 8006af185fd68cc88b2305a02513106c16758d77
Laurent Fainsin 2022-06-30 23:28:38 +02:00
parent e20a989c41
commit 7bdac6583b
5 changed files with 67 additions and 64 deletions

View file

@@ -1 +1 @@
-fb39f9a23b728fadb88ce579f78bb419ff0eaab6
+9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d

View file

@@ -1,45 +0,0 @@
-import numpy as np
-import torch
-from tqdm import tqdm
-
-import wandb
-from src.utils.dice import dice_coeff
-
-class_labels = {
-    1: "sphere",
-}
-
-
-def evaluate(net, dataloader, device):
-    net.eval()
-
-    num_val_batches = len(dataloader)
-    dice_score = 0
-
-    # iterate over the validation set
-    with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
-        for images, masks_true in dataloader:
-            # move images and labels to correct device
-            images = images.to(device=device)
-            masks_true = masks_true.unsqueeze(1).float().to(device=device)
-
-            # forward, predict the mask
-            with torch.inference_mode():
-                masks_pred = net(images)
-                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-
-            # compute the Dice score
-            dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
-
-            # update progress bar
-            pbar.update(images.shape[0])
-
-    # save some images to wandb
-    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-    for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
-        table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
-    wandb.log({"predictions_table": table})
-
-    net.train()
-
-    # Fixes a potential division by zero error
-    return dice_score / num_val_batches if num_val_batches else dice_score
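
Note: the `dice_coeff` helper used above (and re-used in train.py below) is not part of this diff. For orientation only, a minimal sketch of a Dice coefficient over binarized (B, 1, H, W) masks could look like the following; this is an assumption, not the repository's actual `src.utils.dice` implementation (which, per the TODO further down, averages over the batch when `reduce_batch_first=False`):

import torch


def dice_coeff(pred: torch.Tensor, target: torch.Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6) -> torch.Tensor:
    # hypothetical sketch of a Dice coefficient for binary masks of shape (B, 1, H, W)
    if reduce_batch_first:
        # treat the whole batch as a single mask
        inter = (pred * target).sum()
        total = pred.sum() + target.sum()
        return (2 * inter + epsilon) / (total + epsilon)
    # per-sample Dice, averaged over the batch
    inter = (pred * target).flatten(1).sum(dim=1)
    total = pred.flatten(1).sum(dim=1) + target.flatten(1).sum(dim=1)
    return ((2 * inter + epsilon) / (total + epsilon)).mean()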

View file

@@ -8,9 +8,9 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 import wandb
-from evaluate import evaluate
 from src.utils.dataset import SphereDataset
 from unet import UNet
+from utils.dice import dice_coeff
 from utils.paste import RandomPaste
@@ -22,7 +22,7 @@ def main():
     wandb.init(
         project="U-Net",
         config=dict(
-            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
+            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/smolval2017",
             DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
             DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
             DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
@@ -51,7 +51,7 @@ def main():
     # 0. Create network
     net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
     wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
-    wandb.watch(net, log_freq=100)
+    wandb.watch(net, log_freq=100)  # TODO: 1/4 epochs
 
     # transfer network to device
     net.to(device=device)
@@ -110,10 +110,6 @@ def main():
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
 
-    # accuracy stuff
-    mse = torch.nn.MSELoss()
-    mae = torch.nn.L1Loss()
-
     # save model.pth
     torch.save(net.state_dict(), "checkpoints/model-0.pth")
     artifact = wandb.Artifact("pth", type="model")
@@ -136,6 +132,9 @@ def main():
        """
    )
 
+    # setup wandb table for saving images
+    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
+
     try:
         for epoch in range(1, wandb.config.EPOCHS + 1):
             with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
@@ -164,9 +163,9 @@ def main():
                     grad_scaler.update()
 
                     # compute metrics
-                    accuracy = (true_masks == pred_masks).float().mean()
-                    mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
-                    mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
+                    pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
+                    accuracy = (true_masks == pred_masks_bin).float().mean()
+                    mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
 
                     # update tqdm progress bar
                     pbar.update(images.shape[0])
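
As an aside, the binarization introduced above is just a sigmoid followed by a 0.5 cut-off (equivalently, logit > 0), after which pixel accuracy and MAE are computed on hard {0, 1} masks. A tiny standalone illustration with made-up tensors, not the project's data:

import torch

# toy logits and ground truth for a batch of two 1-channel 2x2 masks
pred_masks = torch.tensor([[[[2.0, -1.0], [0.3, -3.0]]],
                           [[[-0.5, 4.0], [1.2, -0.1]]]])
true_masks = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]],
                           [[[0.0, 1.0], [1.0, 1.0]]]])

# sigmoid maps logits to (0, 1); thresholding at 0.5 yields a hard binary mask
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()

accuracy = (true_masks == pred_masks_bin).float().mean()        # fraction of matching pixels -> 0.875
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)   # mean absolute error -> 0.125

print(pred_masks_bin, accuracy.item(), mae.item())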
@@ -177,23 +176,64 @@
                         {
                             "train/epoch": epoch - 1 + step / len(train_loader),
                             "train/accuracy": accuracy,
-                            "train/loss": train_loss,
-                            "train/mse": mse,
+                            "train/bce": train_loss,
                             "train/mae": mae,
                         }
                     )
 
             # Evaluation round
-            val_score = evaluate(net, val_loader, device)
-            scheduler.step(val_score)
+            net.eval()
+            accuracy = 0
+            dice = 0
+            mae = 0
+            with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
+                for images, masks_true in val_loader:
+                    # transfer images to device
+                    images = images.to(device=device)
+                    masks_true = masks_true.unsqueeze(1).to(device=device)
+
+                    # forward
+                    with torch.inference_mode():
+                        masks_pred = net(images)
+
+                    # compute metrics
+                    masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+                    accuracy += (masks_true == masks_pred_bin).float().sum()
+                    dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
+                    mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true, reduction="sum")
+
+                    # update progress bar
+                    pbar.update(images.shape[0])
+
+            accuracy /= len(ds_valid)
+            dice /= len(val_loader)  # TODO: fix dice_coeff to not average
+            mae /= len(ds_valid)
+
+            # save the last validation batch to table
+            for i, (img, mask, pred) in enumerate(
+                zip(
+                    images.to("cpu"),
+                    masks_true.to("cpu"),
+                    masks_pred.to("cpu"),
+                )
+            ):
+                table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
+
             # log validation metrics
             wandb.log(
                 {
-                    "val/val_score": val_score,
+                    "val/predictions": table,
+                    "val/accuracy": accuracy,
+                    "val/dice": dice,
+                    "val/mae": mae,
                 }
             )
+
+            # update hyperparameters
+            net.train()
+            scheduler.step(dice)
 
             # save weights when epoch end
             torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
             artifact = wandb.Artifact("pth", type="model")
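
For context on the `wandb.Table` usage above: a table is filled row by row with `add_data` and then logged once under a key. The standalone sketch below uses dummy arrays and offline mode so it runs without a W&B account; names and shapes are illustrative, not the project's:

import numpy as np
import wandb

run = wandb.init(project="U-Net", mode="offline")  # offline: nothing is uploaded

table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i in range(4):
    img = np.random.rand(64, 64, 3)                           # dummy RGB image in [0, 1]
    mask = (np.random.rand(64, 64) > 0.5).astype(np.float32)  # dummy binary ground-truth mask
    pred = np.random.rand(64, 64).astype(np.float32)          # dummy soft prediction
    table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))

wandb.log({"val/predictions": table})
run.finish()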

View file

@@ -1,4 +1,3 @@
-import logging
 import os
 
 import numpy as np
@@ -24,6 +23,10 @@ class SphereDataset(Dataset):
         if self.transform is not None:
             augmentations = self.transform(image=image, mask=mask)
             image = augmentations["image"]
-            mask = augmentations["mask"].float()
+            mask = augmentations["mask"]
+
+        # make sure image and mask are floats
+        image = image.float()
+        mask = mask.float()
 
         return image, mask
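
The extra casts above matter because albumentations' `ToTensorV2` keeps the mask's original dtype (typically `uint8`), while the network input and the `BCEWithLogitsLoss` target both need to be float tensors. A small self-contained sketch of that dtype behaviour, with a made-up transform and dummy data rather than the repository's actual pipeline:

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

transform = A.Compose(
    [
        A.Resize(512, 512),
        ToTensorV2(),
    ]
)

# dummy sample: a uint8 RGB image and a {0, 1} uint8 mask, as they might come off disk
image = np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8)
mask = (np.random.rand(600, 800) > 0.5).astype(np.uint8)

augmentations = transform(image=image, mask=mask)
image = augmentations["image"].float()  # uint8 (C, H, W) tensor -> float32
mask = augmentations["mask"].float()    # uint8 (H, W) tensor -> float32

print(image.dtype, image.shape, mask.dtype, mask.shape)  # torch.float32 ... torch.float32 ...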

View file

@@ -42,6 +42,7 @@ class RandomPaste(A.DualTransform):
         # convert img to Image, needed for `paste` function
         img = Image.fromarray(img)
 
+        # paste spheres
         for pos in positions:
             img.paste(paste_img, pos, paste_mask)
@@ -51,8 +52,12 @@ class RandomPaste(A.DualTransform):
         # convert mask to Image, needed for `paste` function
         mask = Image.fromarray(mask)
 
+        # binarize the mask -> {0, 1}
+        paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
+
+        # paste spheres
         for pos in positions:
-            mask.paste(paste_mask, pos, paste_mask)
+            mask.paste(paste_mask, pos, paste_mask_bin)
 
         return np.asarray(mask.convert("L"))
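
A caveat worth flagging on the binarization above: when the third argument of `Image.paste` is a mode "L" image, PIL treats its pixel values as an alpha in the 0-255 range, so a mask holding only {0, 1} pastes at roughly 1/255 opacity; binarizing to {0, 255} (or converting the mask to mode "1") gives a fully opaque paste. The snippet below is a standalone check of that PIL behaviour with made-up values, not the repository's `RandomPaste` code:

import numpy as np
from PIL import Image

# an 8x8 grayscale "sphere mask" with value 200 everywhere
paste_mask = Image.fromarray(np.full((8, 8), 200, dtype=np.uint8))

canvas_01 = Image.new("L", (8, 8), 0)
canvas_255 = Image.new("L", (8, 8), 0)

# {0, 1} alpha mask: the paste is blended at 1/255, i.e. nearly invisible
mask_01 = paste_mask.point(lambda p: 1 if p > 10 else 0)
canvas_01.paste(paste_mask, (0, 0), mask_01)

# {0, 255} alpha mask: the paste is fully opaque
mask_255 = paste_mask.point(lambda p: 255 if p > 10 else 0)
canvas_255.paste(paste_mask, (0, 0), mask_255)

print(np.asarray(canvas_01).max(), np.asarray(canvas_255).max())  # ~1 vs 200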