feat: using dice_loss

feat: paste aug contrast/sharpness

Former-commit-id: 93f19e9643858a81ace14e9a697dfb6b3cca4d47 [formerly f6ef5f65e84f37b4b55a99a49442b7d30d6d3911]
Former-commit-id: 2f49a81340a91ab7456d093a849ed294457f8a83
This commit is contained in:
Laurent Fainsin 2022-07-06 14:27:26 +02:00
parent 5a74af6cdb
commit b71b57285f
4 changed files with 46 additions and 137 deletions

View file

@ -23,7 +23,7 @@ CONFIG = {
"BENCHMARK": True,
"DEVICE": "gpu",
"WORKERS": 8,
"EPOCHS": 5,
"EPOCHS": 10,
"BATCH_SIZE": 16,
"LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 1e-8,

View file

@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
import wandb
from src.utils.dataset import SphereDataset
from utils.dice import dice_coeff
from utils.dice import dice_loss
from utils.paste import RandomPaste
from .blocks import *
@ -111,28 +111,29 @@ class UNet(pl.LightningModule):
# forward pass
masks_pred = self(images)
# compute loss
# compute metrics
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
dice = dice_loss(masks_pred, masks_true)
# compute other metrics
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
self.log_dict(
{
"train/accuracy": accuracy,
"train/bce": bce,
"train/dice": dice,
"train/dice_bin": dice_bin,
"train/bce": bce,
"train/mae": mae,
},
)
return dict(
loss=bce,
dice=dice,
accuracy=accuracy,
loss=dice,
bce=bce,
mae=mae,
)
@ -144,17 +145,17 @@ class UNet(pl.LightningModule):
# forward pass
masks_pred = self(images)
# compute loss
# compute metrics
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
dice = dice_loss(masks_pred, masks_true)
# compute other metrics
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
rows = []
if batch_idx < 6:
if batch_idx % 50 == 0:
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
@ -181,9 +182,10 @@ class UNet(pl.LightningModule):
)
return dict(
loss=bce,
dice=dice,
accuracy=accuracy,
loss=dice,
dice_bin=dice_bin,
bce=bce,
mae=mae,
table_rows=rows,
)
@ -191,8 +193,9 @@ class UNet(pl.LightningModule):
def validation_epoch_end(self, validation_outputs):
# metrics unpacking
accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
# table unpacking
@ -201,7 +204,7 @@ class UNet(pl.LightningModule):
rows = list(itertools.chain.from_iterable(rowss))
# logging
try:
try: # required by autofinding, logger replaced by dummy
self.logger.log_table(
key="val/predictions",
columns=columns,
@ -209,11 +212,13 @@ class UNet(pl.LightningModule):
)
except:
pass
self.log_dict(
{
"val/accuracy": accuracy,
"val/bce": loss,
"val/dice": dice,
"val/dice": loss,
"val/dice_bin": dice_bin,
"val/bce": bce,
"val/mae": mae,
}
)
@ -231,48 +236,6 @@ class UNet(pl.LightningModule):
artifact.add_file(f"checkpoints/model.onnx")
wandb.run.log_artifact(artifact)
# def test_step(self, batch, batch_idx):
# # unpacking
# images, masks_true = batch
# masks_true = masks_true.unsqueeze(1)
# masks_pred = self(images)
# masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
# # compute metrics
# loss = F.cross_entropy(masks_pred, masks_true)
# mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
# accuracy = (masks_true == masks_pred_bin).float().mean()
# dice = dice_coeff(masks_pred_bin, masks_true)
# if batch_idx == 0:
# self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
# return loss, dice, accuracy, mae
# def test_step_end(self, test_outputs):
# # unpacking
# list_loss, list_dice, list_accuracy, list_mae = test_outputs
# # averaging
# loss = np.mean(list_loss)
# dice = np.mean(list_dice)
# accuracy = np.mean(list_accuracy)
# mae = np.mean(list_mae)
# # # get learning rate
# # optimizer = self.optimizers[0]
# # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
# wandb.log(
# {
# # "train/learning_rate": learning_rate,
# "test/accuracy": accuracy,
# "test/bce": loss,
# "test/dice": dice,
# "test/mae": mae,
# }
# )
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),

View file

@ -1,80 +1,18 @@
import torch
from torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all batches, or for a single mask.
def dice_score(inputs, targets, smooth=1, logits=True):
# comment out if your model contains a sigmoid or equivalent activation layer
if logits:
inputs = torch.sigmoid(inputs)
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
# flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
Raises:
ValueError: _description_
Returns:
float: _description_
"""
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
intersection = (inputs * targets).sum()
return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
    """Average the Dice coefficient over every slice along dimension 1.

    For each index ``c`` of dimension 1, ``dice_coeff`` is evaluated on
    ``input[:, c, ...]`` and ``target[:, c, ...]``; the results are summed
    and divided by the size of dimension 1.

    Args:
        input (Tensor): predictions; must have the same size as ``target``.
        target (Tensor): reference values; must have the same size as ``input``.
        reduce_batch_first (bool, optional): forwarded unchanged to
            ``dice_coeff``. Defaults to False.
        epsilon (optional): forwarded unchanged to ``dice_coeff``.
            Defaults to 1e-6.

    Raises:
        AssertionError: if ``input`` and ``target`` differ in size.

    Returns:
        float: mean of the per-slice Dice coefficients.
    """
    assert input.size() == target.size()

    num_channels = input.shape[1]
    # one coefficient per slice of dimension 1, accumulated starting from 0
    total = sum(
        dice_coeff(input[:, idx, ...], target[:, idx, ...], reduce_batch_first, epsilon)
        for idx in range(num_channels)
    )
    return total / num_channels
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
    """Dice loss (objective to minimize): one minus the Dice coefficient.

    Args:
        input (Tensor): predictions; must have the same size as ``target``.
        target (Tensor): reference values; must have the same size as ``input``.
        multiclass (bool, optional): when True the coefficient comes from
            ``multiclass_dice_coeff``, otherwise from ``dice_coeff``.
            Defaults to False.

    Raises:
        AssertionError: if ``input`` and ``target`` differ in size.

    Returns:
        float: ``1 - coefficient``, where the coefficient is computed with
        ``reduce_batch_first=True``.
    """
    assert input.size() == target.size()

    # select the coefficient implementation, always reducing over the batch first
    if multiclass:
        coefficient = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        coefficient = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - coefficient
def dice_loss(inputs, targets, smooth=1, logits=True):
    """Return the Dice loss, i.e. one minus ``dice_score``.

    All four arguments are forwarded unchanged, positionally, to
    ``dice_score``; see that function for the meaning of ``smooth`` and
    ``logits``.
    """
    score = dice_score(inputs, targets, smooth, logits)
    return 1 - score

View file

@ -93,6 +93,18 @@ class RandomPaste(A.DualTransform):
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
paste_shape = np.array(paste_img.size, dtype=np.uint)
# change paste_img's brightness randomly
filter = ImageEnhance.Brightness(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# change paste_img's contrast randomly
filter = ImageEnhance.Contrast(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# change paste_img's sharpness randomly
filter = ImageEnhance.Sharpness(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# compute the minimum scaling to fit inside target image
min_scale = np.min(target_shape / paste_shape)
@ -117,10 +129,6 @@ class RandomPaste(A.DualTransform):
# update paste_shape after scaling
paste_shape = np.array(paste_img.size, dtype=np.uint)
# change brightness randomly
filter = ImageEnhance.Brightness(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# generate some positions
positions = []
NB = rd.randint(1, self.nb)