feat: using dice_loss
feat: paste aug contrast/sharpness Former-commit-id: 93f19e9643858a81ace14e9a697dfb6b3cca4d47 [formerly f6ef5f65e84f37b4b55a99a49442b7d30d6d3911] Former-commit-id: 2f49a81340a91ab7456d093a849ed294457f8a83
This commit is contained in:
parent
5a74af6cdb
commit
b71b57285f
|
@ -23,7 +23,7 @@ CONFIG = {
|
|||
"BENCHMARK": True,
|
||||
"DEVICE": "gpu",
|
||||
"WORKERS": 8,
|
||||
"EPOCHS": 5,
|
||||
"EPOCHS": 10,
|
||||
"BATCH_SIZE": 16,
|
||||
"LEARNING_RATE": 1e-4,
|
||||
"WEIGHT_DECAY": 1e-8,
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
import wandb
|
||||
from src.utils.dataset import SphereDataset
|
||||
from utils.dice import dice_coeff
|
||||
from utils.dice import dice_loss
|
||||
from utils.paste import RandomPaste
|
||||
|
||||
from .blocks import *
|
||||
|
@ -111,28 +111,29 @@ class UNet(pl.LightningModule):
|
|||
# forward pass
|
||||
masks_pred = self(images)
|
||||
|
||||
# compute loss
|
||||
# compute metrics
|
||||
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
|
||||
dice = dice_loss(masks_pred, masks_true)
|
||||
|
||||
# compute other metrics
|
||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
|
||||
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
self.log_dict(
|
||||
{
|
||||
"train/accuracy": accuracy,
|
||||
"train/bce": bce,
|
||||
"train/dice": dice,
|
||||
"train/dice_bin": dice_bin,
|
||||
"train/bce": bce,
|
||||
"train/mae": mae,
|
||||
},
|
||||
)
|
||||
|
||||
return dict(
|
||||
loss=bce,
|
||||
dice=dice,
|
||||
accuracy=accuracy,
|
||||
loss=dice,
|
||||
bce=bce,
|
||||
mae=mae,
|
||||
)
|
||||
|
||||
|
@ -144,17 +145,17 @@ class UNet(pl.LightningModule):
|
|||
# forward pass
|
||||
masks_pred = self(images)
|
||||
|
||||
# compute loss
|
||||
# compute metrics
|
||||
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
|
||||
dice = dice_loss(masks_pred, masks_true)
|
||||
|
||||
# compute other metrics
|
||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
|
||||
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
rows = []
|
||||
if batch_idx < 6:
|
||||
if batch_idx % 50 == 0:
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
zip(
|
||||
images.cpu(),
|
||||
|
@ -181,9 +182,10 @@ class UNet(pl.LightningModule):
|
|||
)
|
||||
|
||||
return dict(
|
||||
loss=bce,
|
||||
dice=dice,
|
||||
accuracy=accuracy,
|
||||
loss=dice,
|
||||
dice_bin=dice_bin,
|
||||
bce=bce,
|
||||
mae=mae,
|
||||
table_rows=rows,
|
||||
)
|
||||
|
@ -191,8 +193,9 @@ class UNet(pl.LightningModule):
|
|||
def validation_epoch_end(self, validation_outputs):
|
||||
# matrics unpacking
|
||||
accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
|
||||
dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
|
||||
loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
|
||||
dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
|
||||
bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
|
||||
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
|
||||
|
||||
# table unpacking
|
||||
|
@ -201,7 +204,7 @@ class UNet(pl.LightningModule):
|
|||
rows = list(itertools.chain.from_iterable(rowss))
|
||||
|
||||
# logging
|
||||
try:
|
||||
try: # required by autofinding, logger replaced by dummy
|
||||
self.logger.log_table(
|
||||
key="val/predictions",
|
||||
columns=columns,
|
||||
|
@ -209,11 +212,13 @@ class UNet(pl.LightningModule):
|
|||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
self.log_dict(
|
||||
{
|
||||
"val/accuracy": accuracy,
|
||||
"val/bce": loss,
|
||||
"val/dice": dice,
|
||||
"val/dice": loss,
|
||||
"val/dice_bin": dice_bin,
|
||||
"val/bce": bce,
|
||||
"val/mae": mae,
|
||||
}
|
||||
)
|
||||
|
@ -231,48 +236,6 @@ class UNet(pl.LightningModule):
|
|||
artifact.add_file(f"checkpoints/model.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# def test_step(self, batch, batch_idx):
|
||||
# # unpacking
|
||||
# images, masks_true = batch
|
||||
# masks_true = masks_true.unsqueeze(1)
|
||||
# masks_pred = self(images)
|
||||
# masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
|
||||
# # compute metrics
|
||||
# loss = F.cross_entropy(masks_pred, masks_true)
|
||||
# mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
# accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
# dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
# if batch_idx == 0:
|
||||
# self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
|
||||
|
||||
# return loss, dice, accuracy, mae
|
||||
|
||||
# def test_step_end(self, test_outputs):
|
||||
# # unpacking
|
||||
# list_loss, list_dice, list_accuracy, list_mae = test_outputs
|
||||
|
||||
# # averaging
|
||||
# loss = np.mean(list_loss)
|
||||
# dice = np.mean(list_dice)
|
||||
# accuracy = np.mean(list_accuracy)
|
||||
# mae = np.mean(list_mae)
|
||||
|
||||
# # # get learning rate
|
||||
# # optimizer = self.optimizers[0]
|
||||
# # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||
|
||||
# wandb.log(
|
||||
# {
|
||||
# # "train/learning_rate": learning_rate,
|
||||
# "test/accuracy": accuracy,
|
||||
# "test/bce": loss,
|
||||
# "test/dice": dice,
|
||||
# "test/mae": mae,
|
||||
# }
|
||||
# )
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.RMSprop(
|
||||
self.parameters(),
|
||||
|
|
|
@ -1,80 +1,18 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
||||
"""Average of Dice coefficient for all batches, or for a single mask.
|
||||
def dice_score(inputs, targets, smooth=1, logits=True):
|
||||
# comment out if your model contains a sigmoid or equivalent activation layer
|
||||
if logits:
|
||||
inputs = torch.sigmoid(inputs)
|
||||
|
||||
Args:
|
||||
input (Tensor): _description_
|
||||
target (Tensor): _description_
|
||||
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
||||
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
||||
# flatten label and prediction tensors
|
||||
inputs = inputs.view(-1)
|
||||
targets = targets.view(-1)
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
|
||||
Returns:
|
||||
float: _description_
|
||||
"""
|
||||
assert input.size() == target.size()
|
||||
|
||||
if input.dim() == 2 and reduce_batch_first:
|
||||
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
|
||||
|
||||
if input.dim() == 2 or reduce_batch_first:
|
||||
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
||||
sets_sum = torch.sum(input) + torch.sum(target)
|
||||
|
||||
if sets_sum.item() == 0:
|
||||
sets_sum = 2 * inter
|
||||
|
||||
return (2 * inter + epsilon) / (sets_sum + epsilon)
|
||||
else:
|
||||
# compute and average metric for each batch element
|
||||
dice = 0
|
||||
|
||||
for i in range(input.shape[0]):
|
||||
dice += dice_coeff(input[i, ...], target[i, ...])
|
||||
|
||||
return dice / input.shape[0]
|
||||
intersection = (inputs * targets).sum()
|
||||
return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
|
||||
|
||||
|
||||
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
||||
"""Average of Dice coefficient for all classes.
|
||||
|
||||
Args:
|
||||
input (Tensor): _description_
|
||||
target (Tensor): _description_
|
||||
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
||||
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
||||
|
||||
Returns:
|
||||
float: _description_
|
||||
"""
|
||||
assert input.size() == target.size()
|
||||
|
||||
dice = 0
|
||||
|
||||
for channel in range(input.shape[1]):
|
||||
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
|
||||
|
||||
return dice / input.shape[1]
|
||||
|
||||
|
||||
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
|
||||
"""Dice loss (objective to minimize) between 0 and 1.
|
||||
|
||||
Args:
|
||||
input (Tensor): _description_
|
||||
target (Tensor): _description_
|
||||
multiclass (bool, optional): _description_. Defaults to False.
|
||||
|
||||
Returns:
|
||||
float: _description_
|
||||
"""
|
||||
assert input.size() == target.size()
|
||||
|
||||
fn = multiclass_dice_coeff if multiclass else dice_coeff
|
||||
|
||||
return 1 - fn(input, target, reduce_batch_first=True)
|
||||
def dice_loss(inputs, targets, smooth=1, logits=True):
|
||||
return 1 - dice_score(inputs, targets, smooth, logits)
|
||||
|
|
|
@ -93,6 +93,18 @@ class RandomPaste(A.DualTransform):
|
|||
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
|
||||
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||
|
||||
# change paste_img's brightness randomly
|
||||
filter = ImageEnhance.Brightness(paste_img)
|
||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
||||
|
||||
# change paste_img's contrast randomly
|
||||
filter = ImageEnhance.Contrast(paste_img)
|
||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
||||
|
||||
# change paste_img's sharpness randomly
|
||||
filter = ImageEnhance.Sharpness(paste_img)
|
||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
||||
|
||||
# compute the minimum scaling to fit inside target image
|
||||
min_scale = np.min(target_shape / paste_shape)
|
||||
|
||||
|
@ -117,10 +129,6 @@ class RandomPaste(A.DualTransform):
|
|||
# update paste_shape after scaling
|
||||
paste_shape = np.array(paste_img.size, dtype=np.uint)
|
||||
|
||||
# change brightness randomly
|
||||
filter = ImageEnhance.Brightness(paste_img)
|
||||
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
|
||||
|
||||
# generate some positions
|
||||
positions = []
|
||||
NB = rd.randint(1, self.nb)
|
||||
|
|
Loading…
Reference in a new issue