feat: binarized the masks + lots of new metrics to fix

Former-commit-id: c840d14f722503d241f6bb6d899630ad6345aca0 [formerly e435a21234620add4f0e4e269a4141e5c1508cd9]
Former-commit-id: 8006af185fd68cc88b2305a02513106c16758d77
This commit is contained in:
Laurent Fainsin 2022-06-30 23:28:38 +02:00
parent e20a989c41
commit 7bdac6583b
5 changed files with 67 additions and 64 deletions

View file

@ -1 +1 @@
fb39f9a23b728fadb88ce579f78bb419ff0eaab6 9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d

View file

@ -1,45 +0,0 @@
import numpy as np
import torch
from tqdm import tqdm
import wandb
from src.utils.dice import dice_coeff
# Maps an integer class id to a human-readable label; only id 1 ("sphere") is defined.
# NOTE(review): nothing in the visible part of this file reads this mapping --
# presumably intended for labelled mask overlays in wandb; confirm before removing.
class_labels = {
1: "sphere",
}
def evaluate(net, dataloader, device):
    """Run one validation pass and return the mean Dice score per batch.

    The network is switched to eval mode for the duration of the pass and is
    always switched back to train mode afterwards, even if the pass raises.
    The last processed batch is logged to wandb as a preview table under the
    key ``"predictions_table"``.

    Args:
        net: Model called as ``net(images)``; its output is passed through a
            sigmoid and thresholded at 0.5 to obtain binary masks.
        dataloader: Iterable of ``(images, masks_true)`` batches. Must support
            ``len()`` and expose a ``dataset`` attribute supporting ``len()``
            (used as the progress-bar total).
        device: Device the batches are moved to before the forward pass.

    Returns:
        The accumulated Dice score divided by the number of batches, or the
        raw accumulator (``0``) when the dataloader has no batches.
    """
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # Last processed batch, kept for the wandb preview table. Stays None when
    # the dataloader yields nothing, so logging can be skipped instead of
    # failing on unbound loop variables.
    last_batch = None

    try:
        # The bar is advanced manually (per image), not by iterating `pbar`.
        with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
            for images, masks_true in dataloader:
                # move images and labels to the correct device;
                # unsqueeze(1) inserts a dimension at index 1 of the mask batch
                images = images.to(device=device)
                masks_true = masks_true.unsqueeze(1).float().to(device=device)

                # forward pass without autograd bookkeeping
                with torch.inference_mode():
                    masks_pred = net(images)

                # binarize the prediction before scoring
                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()

                # accumulate the Dice score of this batch
                dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)

                pbar.update(images.shape[0])
                last_batch = (images, masks_true, masks_pred)

        # save some images to wandb (only if at least one batch was seen)
        if last_batch is not None:
            images, masks_true, masks_pred = last_batch
            table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
            # NOTE: the "prediction" column shows the raw network output,
            # not the binarized mask used for the Dice score.
            rows = zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))
            for i, (img, mask, pred) in enumerate(rows):
                table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
            wandb.log({"predictions_table": table})
    finally:
        # restore training mode no matter how the validation pass ended
        net.train()

    # Fixes a potential division by zero error
    return dice_score / num_val_batches if num_val_batches else dice_score

View file

@ -8,9 +8,9 @@ from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
import wandb import wandb
from evaluate import evaluate
from src.utils.dataset import SphereDataset from src.utils.dataset import SphereDataset
from unet import UNet from unet import UNet
from utils.dice import dice_coeff
from utils.paste import RandomPaste from utils.paste import RandomPaste
@ -22,7 +22,7 @@ def main():
wandb.init( wandb.init(
project="U-Net", project="U-Net",
config=dict( config=dict(
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017", DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/smolval2017",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
@ -51,7 +51,7 @@ def main():
# 0. Create network # 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
wandb.watch(net, log_freq=100) wandb.watch(net, log_freq=100) # TODO: 1/4 epochs
# transfer network to device # transfer network to device
net.to(device=device) net.to(device=device)
@ -110,10 +110,6 @@ def main():
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
# accuracy stuff
mse = torch.nn.MSELoss()
mae = torch.nn.L1Loss()
# save model.pth # save model.pth
torch.save(net.state_dict(), "checkpoints/model-0.pth") torch.save(net.state_dict(), "checkpoints/model-0.pth")
artifact = wandb.Artifact("pth", type="model") artifact = wandb.Artifact("pth", type="model")
@ -136,6 +132,9 @@ def main():
""" """
) )
# setup wandb table for saving images
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
try: try:
for epoch in range(1, wandb.config.EPOCHS + 1): for epoch in range(1, wandb.config.EPOCHS + 1):
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar: with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
@ -164,9 +163,9 @@ def main():
grad_scaler.update() grad_scaler.update()
# compute metrics # compute metrics
accuracy = (true_masks == pred_masks).float().mean() pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
mse = torch.nn.functional.mse_loss(pred_masks, true_masks) accuracy = (true_masks == pred_masks_bin).float().mean()
mae = torch.nn.functional.l1_loss(pred_masks, true_masks) mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
# update tqdm progress bar # update tqdm progress bar
pbar.update(images.shape[0]) pbar.update(images.shape[0])
@ -177,23 +176,64 @@ def main():
{ {
"train/epoch": epoch - 1 + step / len(train_loader), "train/epoch": epoch - 1 + step / len(train_loader),
"train/accuracy": accuracy, "train/accuracy": accuracy,
"train/loss": train_loss, "train/bce": train_loss,
"train/mse": mse,
"train/mae": mae, "train/mae": mae,
} }
) )
# Evaluation round # Evaluation round
val_score = evaluate(net, val_loader, device) net.eval()
scheduler.step(val_score) accuracy = 0
dice = 0
mae = 0
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
for images, masks_true in val_loader:
# transfer images to device
images = images.to(device=device)
masks_true = masks_true.unsqueeze(1).to(device=device)
# forward
with torch.inference_mode():
masks_pred = net(images)
# compute metrics
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
accuracy += (true_masks == pred_masks_bin).float().sum()
dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks, reduction="sum")
# update progress bar
pbar.update(images.shape[0])
accuracy /= len(ds_valid)
dice /= len(val_loader) # TODO: fix dice_coeff to not average
mae /= len(ds_valid)
# save the last validation batch to table
for i, (img, mask, pred) in enumerate(
zip(
images.to("cpu"),
masks_true.to("cpu"),
masks_pred.to("cpu"),
)
):
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
# log validation metrics # log validation metrics
wandb.log( wandb.log(
{ {
"val/val_score": val_score, "val/predictions": table,
"val/accuracy": accuracy,
"val/dice": dice,
"val/mae": mae,
} }
) )
# update hyperparameters
net.train()
scheduler.step(dice)
# save weights when epoch end # save weights when epoch end
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth") torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
artifact = wandb.Artifact("pth", type="model") artifact = wandb.Artifact("pth", type="model")

View file

@ -1,4 +1,3 @@
import logging
import os import os
import numpy as np import numpy as np
@ -24,6 +23,10 @@ class SphereDataset(Dataset):
if self.transform is not None: if self.transform is not None:
augmentations = self.transform(image=image, mask=mask) augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"] image = augmentations["image"]
mask = augmentations["mask"].float() mask = augmentations["mask"]
# make sure image and mask are floats
image = image.float()
mask = mask.float()
return image, mask return image, mask

View file

@ -42,6 +42,7 @@ class RandomPaste(A.DualTransform):
# convert img to Image, needed for `paste` function # convert img to Image, needed for `paste` function
img = Image.fromarray(img) img = Image.fromarray(img)
# paste spheres
for pos in positions: for pos in positions:
img.paste(paste_img, pos, paste_mask) img.paste(paste_img, pos, paste_mask)
@ -51,8 +52,12 @@ class RandomPaste(A.DualTransform):
# convert mask to Image, needed for `paste` function # convert mask to Image, needed for `paste` function
mask = Image.fromarray(mask) mask = Image.fromarray(mask)
# binarize the mask -> {0, 1}
paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
# paste spheres
for pos in positions: for pos in positions:
mask.paste(paste_mask, pos, paste_mask) mask.paste(paste_mask, pos, paste_mask_bin)
return np.asarray(mask.convert("L")) return np.asarray(mask.convert("L"))