feat: using dice_loss
feat: paste aug contrast/sharpness

Former-commit-id: 93f19e9643858a81ace14e9a697dfb6b3cca4d47 [formerly f6ef5f65e84f37b4b55a99a49442b7d30d6d3911]
Former-commit-id: 2f49a81340a91ab7456d093a849ed294457f8a83

parent 5a74af6cdb
commit b71b57285f
@@ -23,7 +23,7 @@ CONFIG = {
     "BENCHMARK": True,
     "DEVICE": "gpu",
     "WORKERS": 8,
-    "EPOCHS": 5,
+    "EPOCHS": 10,
     "BATCH_SIZE": 16,
     "LEARNING_RATE": 1e-4,
     "WEIGHT_DECAY": 1e-8,
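The only functional change in this hunk is the longer training run (EPOCHS: 5 to 10). For context, a minimal sketch of how such a CONFIG dict is typically consumed; the `wandb.init` call and the project name below are assumptions, not part of this diff:

```python
import wandb

# hyperparameters from the hunk above, after this commit
CONFIG = {
    "BENCHMARK": True,
    "DEVICE": "gpu",
    "WORKERS": 8,
    "EPOCHS": 10,
    "BATCH_SIZE": 16,
    "LEARNING_RATE": 1e-4,
    "WEIGHT_DECAY": 1e-8,
}

# hypothetical wiring: registering the config with wandb makes every
# hyperparameter change (like this EPOCHS bump) visible per run
wandb.init(project="sphere-segmentation", config=CONFIG)
assert wandb.config.EPOCHS == 10
```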
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 
 import wandb
 from src.utils.dataset import SphereDataset
-from utils.dice import dice_coeff
+from utils.dice import dice_loss
 from utils.paste import RandomPaste
 
 from .blocks import *
@@ -111,28 +111,29 @@ class UNet(pl.LightningModule):
         # forward pass
         masks_pred = self(images)
 
-        # compute loss
+        # compute metrics
         bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
 
-        # compute other metrics
         masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
         mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
         accuracy = (masks_true == masks_pred_bin).float().mean()
-        dice = dice_coeff(masks_pred_bin, masks_true)
 
         self.log_dict(
             {
                 "train/accuracy": accuracy,
-                "train/bce": bce,
                 "train/dice": dice,
+                "train/dice_bin": dice_bin,
+                "train/bce": bce,
                 "train/mae": mae,
             },
         )
 
         return dict(
-            loss=bce,
-            dice=dice,
             accuracy=accuracy,
+            loss=dice,
+            bce=bce,
             mae=mae,
         )
 
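The key behavioural change in this hunk: the optimised objective switches from BCE to soft dice (`loss=dice`), while BCE and a thresholded "hard" dice are kept as logged metrics only. A toy sketch of the distinction, using the new `dice_loss` signature from `utils.dice` (shapes are illustrative):

```python
import torch

from utils.dice import dice_loss

# fake batch: raw logits from the network and binary ground-truth masks
masks_pred = torch.randn(4, 1, 64, 64, requires_grad=True)
masks_true = torch.randint(0, 2, (4, 1, 64, 64)).float()

# soft dice on logits: differentiable, usable as the training loss
dice = dice_loss(masks_pred, masks_true)

# hard dice on thresholded predictions: a metric only, gradients are lost
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)

print(dice.requires_grad, dice_bin.requires_grad)  # True False
```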
@@ -144,17 +145,17 @@ class UNet(pl.LightningModule):
         # forward pass
         masks_pred = self(images)
 
-        # compute loss
+        # compute metrics
         bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
 
-        # compute other metrics
         masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
         mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
         accuracy = (masks_true == masks_pred_bin).float().mean()
-        dice = dice_coeff(masks_pred_bin, masks_true)
 
         rows = []
-        if batch_idx < 6:
+        if batch_idx % 50 == 0:
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip(
                     images.cpu(),
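The table-sampling condition also changes: instead of logging only the first six validation batches, every fiftieth batch now contributes rows, spreading the logged predictions across the whole set. A quick illustration for a hypothetical 500-batch validation epoch:

```python
# old rule: only the very first batches were ever visualised
logged_old = [i for i in range(500) if i < 6]        # [0, 1, 2, 3, 4, 5]

# new rule: examples are drawn evenly across the epoch
logged_new = [i for i in range(500) if i % 50 == 0]  # [0, 50, 100, ..., 450]
```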
@@ -181,9 +182,10 @@ class UNet(pl.LightningModule):
         )
 
         return dict(
-            loss=bce,
-            dice=dice,
             accuracy=accuracy,
+            loss=dice,
+            dice_bin=dice_bin,
+            bce=bce,
             mae=mae,
             table_rows=rows,
         )
@@ -191,8 +193,9 @@ class UNet(pl.LightningModule):
     def validation_epoch_end(self, validation_outputs):
         # metrics unpacking
         accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
+        dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
         loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
-        dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
+        bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
         mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
 
         # table unpacking
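For reference, a minimal sketch of the epoch-end aggregation pattern used here: each `validation_step` returns a dict of scalar tensors, and the epoch hook stacks and averages them. The dummy outputs below stand in for real batches:

```python
import torch

validation_outputs = [
    {"loss": torch.tensor(0.30), "dice_bin": torch.tensor(0.25), "bce": torch.tensor(0.40)},
    {"loss": torch.tensor(0.20), "dice_bin": torch.tensor(0.15), "bce": torch.tensor(0.30)},
]

loss = torch.stack([d["loss"] for d in validation_outputs]).mean()          # 0.25
dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()  # 0.20
bce = torch.stack([d["bce"] for d in validation_outputs]).mean()            # 0.35
```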
@@ -201,7 +204,7 @@ class UNet(pl.LightningModule):
         rows = list(itertools.chain.from_iterable(rowss))
 
         # logging
-        try:
+        try:  # required by autofinding, logger replaced by dummy
             self.logger.log_table(
                 key="val/predictions",
                 columns=columns,
@@ -209,11 +212,13 @@ class UNet(pl.LightningModule):
             )
         except:
             pass
 
         self.log_dict(
             {
                 "val/accuracy": accuracy,
-                "val/bce": loss,
-                "val/dice": dice,
+                "val/dice": loss,
+                "val/dice_bin": dice_bin,
+                "val/bce": bce,
                 "val/mae": mae,
             }
         )
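The bare `try/except` above exists because Lightning's tuner (auto lr/batch-size finding) temporarily swaps the logger for a dummy that lacks `log_table`. A narrower alternative guard for the same spot inside `validation_epoch_end`; this is a sketch of what could replace it, not what the commit does, and it assumes the logger in use is `pytorch_lightning.loggers.WandbLogger`:

```python
from pytorch_lightning.loggers import WandbLogger

# drop-in replacement for the try/except around log_table;
# `columns` and `rows` come from the surrounding method
if isinstance(self.logger, WandbLogger):
    self.logger.log_table(key="val/predictions", columns=columns, data=rows)
```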
@@ -231,48 +236,6 @@ class UNet(pl.LightningModule):
         artifact.add_file(f"checkpoints/model.onnx")
         wandb.run.log_artifact(artifact)
 
-    # def test_step(self, batch, batch_idx):
-    #     # unpacking
-    #     images, masks_true = batch
-    #     masks_true = masks_true.unsqueeze(1)
-    #     masks_pred = self(images)
-    #     masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-
-    #     # compute metrics
-    #     loss = F.cross_entropy(masks_pred, masks_true)
-    #     mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-    #     accuracy = (masks_true == masks_pred_bin).float().mean()
-    #     dice = dice_coeff(masks_pred_bin, masks_true)
-
-    #     if batch_idx == 0:
-    #         self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
-
-    #     return loss, dice, accuracy, mae
-
-    # def test_step_end(self, test_outputs):
-    #     # unpacking
-    #     list_loss, list_dice, list_accuracy, list_mae = test_outputs
-
-    #     # averaging
-    #     loss = np.mean(list_loss)
-    #     dice = np.mean(list_dice)
-    #     accuracy = np.mean(list_accuracy)
-    #     mae = np.mean(list_mae)
-
-    #     # # get learning rate
-    #     # optimizer = self.optimizers[0]
-    #     # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
-
-    #     wandb.log(
-    #         {
-    #             # "train/learning_rate": learning_rate,
-    #             "test/accuracy": accuracy,
-    #             "test/bce": loss,
-    #             "test/dice": dice,
-    #             "test/mae": mae,
-    #         }
-    #     )
-
     def configure_optimizers(self):
         optimizer = torch.optim.RMSprop(
             self.parameters(),
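The diff cuts off inside `configure_optimizers`, so the remaining optimizer arguments are not shown. A plausible completion, assuming the values come from the wandb config populated by the CONFIG dict shown earlier (entirely an assumption, not taken from this commit):

```python
def configure_optimizers(self):
    # hypothetical completion: the lr and weight_decay sources below are
    # assumptions based on the CONFIG keys shown earlier, not this diff
    optimizer = torch.optim.RMSprop(
        self.parameters(),
        lr=wandb.config.LEARNING_RATE,
        weight_decay=wandb.config.WEIGHT_DECAY,
    )
    return optimizer
```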
@@ -1,80 +1,18 @@
 import torch
-from torch import Tensor
-
-
-def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
-    """Average of Dice coefficient for all batches, or for a single mask.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Raises:
-        ValueError: _description_
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    if input.dim() == 2 and reduce_batch_first:
-        raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
-
-    if input.dim() == 2 or reduce_batch_first:
-        inter = torch.dot(input.reshape(-1), target.reshape(-1))
-        sets_sum = torch.sum(input) + torch.sum(target)
-
-        if sets_sum.item() == 0:
-            sets_sum = 2 * inter
-
-        return (2 * inter + epsilon) / (sets_sum + epsilon)
-    else:
-        # compute and average metric for each batch element
-        dice = 0
-
-        for i in range(input.shape[0]):
-            dice += dice_coeff(input[i, ...], target[i, ...])
-
-        return dice / input.shape[0]
-
-
-def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
-    """Average of Dice coefficient for all classes.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    dice = 0
-
-    for channel in range(input.shape[1]):
-        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
-
-    return dice / input.shape[1]
-
-
-def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
-    """Dice loss (objective to minimize) between 0 and 1.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        multiclass (bool, optional): _description_. Defaults to False.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    fn = multiclass_dice_coeff if multiclass else dice_coeff
-
-    return 1 - fn(input, target, reduce_batch_first=True)
+
+
+def dice_score(inputs, targets, smooth=1, logits=True):
+    # comment out if your model contains a sigmoid or equivalent activation layer
+    if logits:
+        inputs = torch.sigmoid(inputs)
+
+    # flatten label and prediction tensors
+    inputs = inputs.view(-1)
+    targets = targets.view(-1)
+
+    intersection = (inputs * targets).sum()
+    return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
+
+
+def dice_loss(inputs, targets, smooth=1, logits=True):
+    return 1 - dice_score(inputs, targets, smooth, logits)
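This hunk replaces the old per-batch/multiclass dice helpers with a single smoothed, flattened implementation (the `utils.dice` module, per the import change above). A quick sanity check of the new helpers: a perfect prediction drives `dice_score` to roughly 1 and `dice_loss` to roughly 0, and the `smooth=1` term keeps the ratio finite when both masks are empty:

```python
import torch

from utils.dice import dice_loss, dice_score

target = torch.zeros(1, 1, 4, 4)
target[..., :2] = 1.0

# logits that saturate the sigmoid towards the target mask
perfect_logits = target * 200.0 - 100.0
print(dice_score(perfect_logits, target))  # ~1.0
print(dice_loss(perfect_logits, target))   # ~0.0

empty = torch.zeros(1, 1, 4, 4)
print(dice_score(empty, empty, logits=False))  # 1.0 thanks to the smooth term
```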
@@ -93,6 +93,18 @@ class RandomPaste(A.DualTransform):
         target_shape = np.array(target_img.shape[:2], dtype=np.uint)
         paste_shape = np.array(paste_img.size, dtype=np.uint)
 
+        # change paste_img's brightness randomly
+        filter = ImageEnhance.Brightness(paste_img)
+        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
+
+        # change paste_img's contrast randomly
+        filter = ImageEnhance.Contrast(paste_img)
+        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
+
+        # change paste_img's sharpness randomly
+        filter = ImageEnhance.Sharpness(paste_img)
+        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
+
         # compute the minimum scaling to fit inside target image
         min_scale = np.min(target_shape / paste_shape)
 
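This extends the RandomPaste transform's photometric jitter from brightness alone to brightness, contrast, and sharpness, applied before scaling. A standalone sketch of the same pattern: PIL's ImageEnhance classes return a new image, where a factor of 1.0 is the identity, below 1 reduces and above 1 boosts the property (the dummy image below is an assumption for illustration):

```python
import random as rd

from PIL import Image, ImageEnhance

paste_img = Image.new("RGB", (64, 64), (128, 64, 32))  # dummy paste image

# same 0.5-1.5 factor range as the commit, applied per property
for enhancer in (ImageEnhance.Brightness, ImageEnhance.Contrast, ImageEnhance.Sharpness):
    paste_img = enhancer(paste_img).enhance(rd.uniform(0.5, 1.5))
```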
@@ -117,10 +129,6 @@ class RandomPaste(A.DualTransform):
         # update paste_shape after scaling
         paste_shape = np.array(paste_img.size, dtype=np.uint)
 
-        # change brightness randomly
-        filter = ImageEnhance.Brightness(paste_img)
-        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
-
         # generate some positions
         positions = []
         NB = rd.randint(1, self.nb)