feat: binarized the masks + lots of new metrics to fix
Former-commit-id: c840d14f722503d241f6bb6d899630ad6345aca0 [formerly e435a21234620add4f0e4e269a4141e5c1508cd9]
Former-commit-id: 8006af185fd68cc88b2305a02513106c16758d77
parent e20a989c41
commit 7bdac6583b
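The pattern running through this commit: predicted masks are binarized (sigmoid, then a 0.5 threshold) before any metric is compared against the ground truth. A minimal sketch of that pattern, with illustrative names (`logits` and `targets` are not from the repo):

    import torch

    # logits: raw network output, shape (N, 1, H, W)
    logits = torch.randn(4, 1, 64, 64)
    # targets: ground-truth masks with values in {0, 1}
    targets = (torch.rand(4, 1, 64, 64) > 0.5).float()

    # binarize: sigmoid maps logits into (0, 1), then threshold at 0.5
    preds_bin = (torch.sigmoid(logits) > 0.5).float()

    # metrics are then computed on the binary mask
    accuracy = (preds_bin == targets).float().mean()
    mae = torch.nn.functional.l1_loss(preds_bin, targets)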
@@ -1 +1 @@
-fb39f9a23b728fadb88ce579f78bb419ff0eaab6
+9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d
src/evaluate.py
@@ -1,45 +0,0 @@
-import numpy as np
-import torch
-from tqdm import tqdm
-
-import wandb
-from src.utils.dice import dice_coeff
-
-class_labels = {
-    1: "sphere",
-}
-
-
-def evaluate(net, dataloader, device):
-    net.eval()
-    num_val_batches = len(dataloader)
-    dice_score = 0
-
-    # iterate over the validation set
-    with tqdm(dataloader, total=len(dataloader.dataset), desc="val", unit="img", leave=False) as pbar:
-        for images, masks_true in dataloader:
-            # move images and labels to correct device
-            images = images.to(device=device)
-            masks_true = masks_true.unsqueeze(1).float().to(device=device)
-
-            # forward, predict the mask
-            with torch.inference_mode():
-                masks_pred = net(images)
-                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-
-            # compute the Dice score
-            dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
-
-            # update progress bar
-            pbar.update(images.shape[0])
-
-    # save some images to wandb
-    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-    for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
-        table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
-    wandb.log({"predictions_table": table})
-
-    net.train()
-
-    # Fixes a potential division by zero error
-    return dice_score / num_val_batches if num_val_batches else dice_score
src/train.py (70 changed lines)
@@ -8,9 +8,9 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 import wandb
-from evaluate import evaluate
 from src.utils.dataset import SphereDataset
 from unet import UNet
+from utils.dice import dice_coeff
 from utils.paste import RandomPaste
 
 
@@ -22,7 +22,7 @@ def main():
     wandb.init(
         project="U-Net",
         config=dict(
-            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
+            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/smolval2017",
             DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
             DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
             DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
@@ -51,7 +51,7 @@ def main():
     # 0. Create network
     net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
     wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
-    wandb.watch(net, log_freq=100)
+    wandb.watch(net, log_freq=100)  # TODO: 1/4 epochs
 
     # transfer network to device
     net.to(device=device)
@@ -110,10 +110,6 @@ def main():
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
 
-    # accuracy stuff
-    mse = torch.nn.MSELoss()
-    mae = torch.nn.L1Loss()
-
     # save model.pth
     torch.save(net.state_dict(), "checkpoints/model-0.pth")
     artifact = wandb.Artifact("pth", type="model")
@@ -136,6 +132,9 @@ def main():
         """
     )
 
+    # setup wandb table for saving images
+    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
+
     try:
         for epoch in range(1, wandb.config.EPOCHS + 1):
            with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
@@ -164,9 +163,9 @@ def main():
                    grad_scaler.update()
 
                    # compute metrics
-                   accuracy = (true_masks == pred_masks).float().mean()
-                   mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
-                   mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
+                   pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
+                   accuracy = (true_masks == pred_masks_bin).float().mean()
+                   mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
 
                    # update tqdm progress bar
                    pbar.update(images.shape[0])
@@ -177,23 +176,64 @@ def main():
                        {
                            "train/epoch": epoch - 1 + step / len(train_loader),
                            "train/accuracy": accuracy,
-                           "train/loss": train_loss,
-                           "train/mse": mse,
+                           "train/bce": train_loss,
                            "train/mae": mae,
                        }
                    )
 
            # Evaluation round
-           val_score = evaluate(net, val_loader, device)
-           scheduler.step(val_score)
+           net.eval()
+           accuracy = 0
+           dice = 0
+           mae = 0
+           with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
+               for images, masks_true in val_loader:
+
+                   # transfer images to device
+                   images = images.to(device=device)
+                   masks_true = masks_true.unsqueeze(1).to(device=device)
+
+                   # forward
+                   with torch.inference_mode():
+                       masks_pred = net(images)
+
+                   # compute metrics
+                   masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+                   accuracy += (masks_true == masks_pred_bin).float().sum()
+                   dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
+                   mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true, reduction="sum")
+
+                   # update progress bar
+                   pbar.update(images.shape[0])
+
+           accuracy /= len(ds_valid)
+           dice /= len(val_loader)  # TODO: fix dice_coeff to not average
+           mae /= len(ds_valid)
+
+           # save the last validation batch to table
+           for i, (img, mask, pred) in enumerate(
+               zip(
+                   images.to("cpu"),
+                   masks_true.to("cpu"),
+                   masks_pred.to("cpu"),
+               )
+           ):
+               table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
+
+           # log validation metrics
            wandb.log(
                {
-                   "val/val_score": val_score,
+                   "val/predictions": table,
+                   "val/accuracy": accuracy,
+                   "val/dice": dice,
+                   "val/mae": mae,
                }
            )
 
+           # update hyperparameters
            net.train()
+           scheduler.step(dice)
 
            # save weights when epoch end
            torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
            artifact = wandb.Artifact("pth", type="model")
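A note on the accumulation above: with `reduction="sum"` the per-pixel errors are summed over the whole loader, and the totals are then divided by dataset-level counts. A sketch of the same loop normalized per pixel instead of per image (illustrative only, not code from this commit; `net`, `val_loader`, and `device` as in train.py):

    import torch

    total_correct, total_pixels = 0, 0
    with torch.inference_mode():
        for images, masks_true in val_loader:
            images = images.to(device=device)
            masks_true = masks_true.unsqueeze(1).float().to(device=device)
            # binarize the predictions, as elsewhere in this commit
            masks_pred_bin = (torch.sigmoid(net(images)) > 0.5).float()
            total_correct += (masks_pred_bin == masks_true).sum().item()
            total_pixels += masks_true.numel()

    accuracy = total_correct / total_pixels  # a true fraction in [0, 1]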
src/utils/dataset.py
@@ -1,4 +1,3 @@
 import logging
-import os
 
 import numpy as np
@@ -24,6 +23,10 @@ class SphereDataset(Dataset):
         if self.transform is not None:
             augmentations = self.transform(image=image, mask=mask)
             image = augmentations["image"]
-            mask = augmentations["mask"].float()
+            mask = augmentations["mask"]
 
+        # make sure image and mask are floats
+        image = image.float()
+        mask = mask.float()
+
         return image, mask
src/utils/paste.py
@@ -42,6 +42,7 @@ class RandomPaste(A.DualTransform):
         # convert img to Image, needed for `paste` function
         img = Image.fromarray(img)
 
         # paste spheres
         for pos in positions:
             img.paste(paste_img, pos, paste_mask)
+
@@ -51,8 +52,12 @@ class RandomPaste(A.DualTransform):
         # convert mask to Image, needed for `paste` function
         mask = Image.fromarray(mask)
 
+        # binarize the mask -> {0, 1}
+        paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
+
         # paste spheres
         for pos in positions:
-            mask.paste(paste_mask, pos, paste_mask)
+            mask.paste(paste_mask, pos, paste_mask_bin)
 
         return np.asarray(mask.convert("L"))
+
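The `point` call above applies the lambda to every pixel of the grayscale paste mask, so values above 10 become 1 and everything else becomes 0. The same threshold expressed on a NumPy array (illustrative, not from the repo):

    import numpy as np

    # a grayscale mask with values in 0..255
    gray = np.array([[0, 5, 12], [200, 9, 255]], dtype=np.uint8)

    # anything above the threshold becomes 1, the rest 0
    binary = (gray > 10).astype(np.uint8)
    # -> [[0, 0, 1],
    #     [1, 0, 1]]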