2017-08-16 12:24:29 +00:00
|
|
|
import torch
|
2019-10-24 19:37:21 +00:00
|
|
|
from torch.autograd import Function
|
|
|
|
|
2017-08-17 19:16:19 +00:00
|
|
|
|
|
|
|
class DiceCoeff(Function):
    """Dice coeff for individual examples.

    NOTE(review): this uses the legacy (pre-PyTorch-0.4) autograd Function
    style, where state is stored on the instance (`self.save_for_backward`,
    `self.saved_variables`) instead of a `ctx` object.  That API is
    deprecated in modern PyTorch; each call requires a fresh instance.
    """

    def forward(self, input, target):
        # Stash the tensors for use in backward().
        self.save_for_backward(input, target)
        # Small constant to keep the denominator strictly positive when
        # both masks are all zeros.
        eps = 0.0001
        # Intersection: element-wise overlap of the flattened tensors
        # (assumes input/target are soft/binary masks of the same shape —
        # TODO confirm against callers).
        self.inter = torch.dot(input.view(-1), target.view(-1))
        # "Union" term of the Dice formula: |A| + |B|, padded with eps.
        self.union = torch.sum(input) + torch.sum(target) + eps

        # Dice = 2*|A∩B| / (|A| + |B|).  The eps in the numerator makes
        # the coefficient evaluate to 1 (not 0/0) when both masks are
        # empty: (2*0 + eps) / eps == 1.
        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            # Quotient rule on 2*inter/union: d(inter)/d(input) is `target`,
            # d(union)/d(input) is 1 (element-wise).
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            # Gradient w.r.t. the target is deliberately not computed
            # (ground-truth labels are not optimized).
            grad_target = None

        return grad_input, grad_target
|
|
|
|
|
2017-08-17 19:16:19 +00:00
|
|
|
|
2017-08-16 12:24:29 +00:00
|
|
|
def dice_coeff(input, target):
    """Dice coeff for batches.

    Computes the mean Dice coefficient over the samples of a batch by
    applying ``DiceCoeff`` to each (prediction, ground-truth) pair.

    Args:
        input: batch of predicted masks; iterated along dim 0.
        target: batch of ground-truth masks, same leading size as ``input``.

    Returns:
        A 1-element float tensor holding the mean Dice coefficient.

    Raises:
        ValueError: if the batch is empty.  (Previously this raised an
        obscure ``NameError`` because the loop variable was referenced
        after a zero-iteration loop.)
    """
    # Accumulator on the same device as the predictions.
    if input.is_cuda:
        s = torch.FloatTensor(1).cuda().zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    # Count samples explicitly instead of reusing the leaked loop index.
    n_samples = 0
    for pred, true in zip(input, target):
        s = s + DiceCoeff().forward(pred, true)
        n_samples += 1

    if n_samples == 0:
        raise ValueError("dice_coeff received an empty batch")
    return s / n_samples
|