feat: binarized the masks + lots of new metrics to fix
Former-commit-id: c840d14f722503d241f6bb6d899630ad6345aca0 [formerly e435a21234620add4f0e4e269a4141e5c1508cd9] Former-commit-id: 8006af185fd68cc88b2305a02513106c16758d77
This commit is contained in:
parent
e20a989c41
commit
7bdac6583b
|
@ -1 +1 @@
|
||||||
fb39f9a23b728fadb88ce579f78bb419ff0eaab6
|
9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d
|
|
@ -1,45 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import wandb
|
|
||||||
from src.utils.dice import dice_coeff
|
|
||||||
|
|
||||||
class_labels = {
|
|
||||||
1: "sphere",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(net, dataloader, device):
|
|
||||||
net.eval()
|
|
||||||
num_val_batches = len(dataloader)
|
|
||||||
dice_score = 0
|
|
||||||
|
|
||||||
# iterate over the validation set
|
|
||||||
with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
|
|
||||||
for images, masks_true in dataloader:
|
|
||||||
# move images and labels to correct device
|
|
||||||
images = images.to(device=device)
|
|
||||||
masks_true = masks_true.unsqueeze(1).float().to(device=device)
|
|
||||||
|
|
||||||
# forward, predict the mask
|
|
||||||
with torch.inference_mode():
|
|
||||||
masks_pred = net(images)
|
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
|
||||||
|
|
||||||
# compute the Dice score
|
|
||||||
dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
|
|
||||||
|
|
||||||
# update progress bar
|
|
||||||
pbar.update(images.shape[0])
|
|
||||||
|
|
||||||
# save some images to wandb
|
|
||||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
|
||||||
for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
|
|
||||||
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
|
||||||
wandb.log({"predictions_table": table})
|
|
||||||
|
|
||||||
net.train()
|
|
||||||
|
|
||||||
# Fixes a potential division by zero error
|
|
||||||
return dice_score / num_val_batches if num_val_batches else dice_score
|
|
70
src/train.py
70
src/train.py
|
@ -8,9 +8,9 @@ from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from evaluate import evaluate
|
|
||||||
from src.utils.dataset import SphereDataset
|
from src.utils.dataset import SphereDataset
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
|
from utils.dice import dice_coeff
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ def main():
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project="U-Net",
|
project="U-Net",
|
||||||
config=dict(
|
config=dict(
|
||||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
|
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/smolval2017",
|
||||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||||
|
@ -51,7 +51,7 @@ def main():
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||||
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
wandb.watch(net, log_freq=100)
|
wandb.watch(net, log_freq=100) # TODO: 1/4 epochs
|
||||||
|
|
||||||
# transfer network to device
|
# transfer network to device
|
||||||
net.to(device=device)
|
net.to(device=device)
|
||||||
|
@ -110,10 +110,6 @@ def main():
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||||
criterion = torch.nn.BCEWithLogitsLoss()
|
criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# accuracy stuff
|
|
||||||
mse = torch.nn.MSELoss()
|
|
||||||
mae = torch.nn.L1Loss()
|
|
||||||
|
|
||||||
# save model.pth
|
# save model.pth
|
||||||
torch.save(net.state_dict(), "checkpoints/model-0.pth")
|
torch.save(net.state_dict(), "checkpoints/model-0.pth")
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
|
@ -136,6 +132,9 @@ def main():
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# setup wandb table for saving images
|
||||||
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(1, wandb.config.EPOCHS + 1):
|
for epoch in range(1, wandb.config.EPOCHS + 1):
|
||||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
||||||
|
@ -164,9 +163,9 @@ def main():
|
||||||
grad_scaler.update()
|
grad_scaler.update()
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
accuracy = (true_masks == pred_masks).float().mean()
|
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
|
||||||
mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
|
accuracy = (true_masks == pred_masks_bin).float().mean()
|
||||||
mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
|
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
||||||
|
|
||||||
# update tqdm progress bar
|
# update tqdm progress bar
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
|
@ -177,23 +176,64 @@ def main():
|
||||||
{
|
{
|
||||||
"train/epoch": epoch - 1 + step / len(train_loader),
|
"train/epoch": epoch - 1 + step / len(train_loader),
|
||||||
"train/accuracy": accuracy,
|
"train/accuracy": accuracy,
|
||||||
"train/loss": train_loss,
|
"train/bce": train_loss,
|
||||||
"train/mse": mse,
|
|
||||||
"train/mae": mae,
|
"train/mae": mae,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
val_score = evaluate(net, val_loader, device)
|
net.eval()
|
||||||
scheduler.step(val_score)
|
accuracy = 0
|
||||||
|
dice = 0
|
||||||
|
mae = 0
|
||||||
|
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
|
||||||
|
for images, masks_true in val_loader:
|
||||||
|
|
||||||
|
# transfer images to device
|
||||||
|
images = images.to(device=device)
|
||||||
|
masks_true = masks_true.unsqueeze(1).to(device=device)
|
||||||
|
|
||||||
|
# forward
|
||||||
|
with torch.inference_mode():
|
||||||
|
masks_pred = net(images)
|
||||||
|
|
||||||
|
# compute metrics
|
||||||
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
|
accuracy += (true_masks == pred_masks_bin).float().sum()
|
||||||
|
dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
|
||||||
|
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks, reduction="sum")
|
||||||
|
|
||||||
|
# update progress bar
|
||||||
|
pbar.update(images.shape[0])
|
||||||
|
|
||||||
|
accuracy /= len(ds_valid)
|
||||||
|
dice /= len(val_loader) # TODO: fix dice_coeff to not average
|
||||||
|
mae /= len(ds_valid)
|
||||||
|
|
||||||
|
# save the last validation batch to table
|
||||||
|
for i, (img, mask, pred) in enumerate(
|
||||||
|
zip(
|
||||||
|
images.to("cpu"),
|
||||||
|
masks_true.to("cpu"),
|
||||||
|
masks_pred.to("cpu"),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
||||||
|
|
||||||
# log validation metrics
|
# log validation metrics
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"val/val_score": val_score,
|
"val/predictions": table,
|
||||||
|
"val/accuracy": accuracy,
|
||||||
|
"val/dice": dice,
|
||||||
|
"val/mae": mae,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# update hyperparameters
|
||||||
|
net.train()
|
||||||
|
scheduler.step(dice)
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
|
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -24,6 +23,10 @@ class SphereDataset(Dataset):
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
augmentations = self.transform(image=image, mask=mask)
|
augmentations = self.transform(image=image, mask=mask)
|
||||||
image = augmentations["image"]
|
image = augmentations["image"]
|
||||||
mask = augmentations["mask"].float()
|
mask = augmentations["mask"]
|
||||||
|
|
||||||
|
# make sure image and mask are floats
|
||||||
|
image = image.float()
|
||||||
|
mask = mask.float()
|
||||||
|
|
||||||
return image, mask
|
return image, mask
|
||||||
|
|
|
@ -42,6 +42,7 @@ class RandomPaste(A.DualTransform):
|
||||||
# convert img to Image, needed for `paste` function
|
# convert img to Image, needed for `paste` function
|
||||||
img = Image.fromarray(img)
|
img = Image.fromarray(img)
|
||||||
|
|
||||||
|
# paste spheres
|
||||||
for pos in positions:
|
for pos in positions:
|
||||||
img.paste(paste_img, pos, paste_mask)
|
img.paste(paste_img, pos, paste_mask)
|
||||||
|
|
||||||
|
@ -51,8 +52,12 @@ class RandomPaste(A.DualTransform):
|
||||||
# convert mask to Image, needed for `paste` function
|
# convert mask to Image, needed for `paste` function
|
||||||
mask = Image.fromarray(mask)
|
mask = Image.fromarray(mask)
|
||||||
|
|
||||||
|
# binarize the mask -> {0, 1}
|
||||||
|
paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
|
||||||
|
|
||||||
|
# paste spheres
|
||||||
for pos in positions:
|
for pos in positions:
|
||||||
mask.paste(paste_mask, pos, paste_mask)
|
mask.paste(paste_mask, pos, paste_mask_bin)
|
||||||
|
|
||||||
return np.asarray(mask.convert("L"))
|
return np.asarray(mask.convert("L"))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue