mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
3fefc25199
Former-commit-id: 1e8093de0398364cc65f3c7722b61a0f560df104
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import torch
|
|
from torch.autograd import Function, Variable
|
|
|
|
class DiceCoeff(Function):
    """Dice coefficient for an individual example.

    Legacy-style (instance-method) autograd ``Function``: ``forward`` and
    ``backward`` are ordinary methods and intermediate values are stored on
    ``self``.  When ``forward`` is called directly on an instance, gradients
    flow through the regular tensor ops it uses; ``backward`` implements the
    same derivative analytically.
    """

    # Smoothing term added to both numerator and denominator so that two
    # all-zero masks give a coefficient of 1 instead of 0/0.
    eps = 0.0001

    def forward(self, input, target):
        """Return ``(2 * sum(input * target) + eps) / (sum(input) + sum(target) + eps)``.

        Args:
            input: prediction tensor of any shape (flattened internally).
            target: ground-truth tensor with the same number of elements as
                ``input``; it is cast to ``input``'s dtype so integer or bool
                masks are accepted.

        Returns:
            0-d float32 tensor holding the Dice coefficient.
        """
        # torch.dot requires both operands to share a dtype.
        target = target.to(input.dtype)
        self.save_for_backward(input, target)

        # reshape (unlike view) also works on non-contiguous tensors.
        self.inter = torch.dot(input.reshape(-1), target.reshape(-1))
        self.union = torch.sum(input) + torch.sum(target) + self.eps

        return (2 * self.inter.float() + self.eps) / self.union.float()

    @staticmethod
    def _dice_grad(target, inter, union, eps):
        """Derivative of ``(2 * inter + eps) / union`` w.r.t. each input element.

        With ``I = sum(input * target)`` and ``U = sum(input) + sum(target) + eps``:
        ``dD/dinput_i = (2 * target_i * U - (2 * I + eps)) / U**2``.
        """
        return (2 * target * union - (2 * inter + eps)) / (union * union)

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Propagate ``grad_output`` to ``input``; ``target`` gets no gradient."""
        _, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * self._dice_grad(
                target, self.inter, self.union, self.eps
            )

        # The target (labels) is treated as a constant, so grad_target stays None
        # even if needs_input_grad[1] is set.
        return grad_input, grad_target
|
|
|
|
|
|
def dice_coeff(input, target):
    """Mean Dice coefficient over a batch.

    Walks ``input`` and ``target`` in lockstep along their first dimension
    and averages ``DiceCoeff().forward`` over the examples.

    Args:
        input: batch of predictions, indexed along dim 0.
        target: batch of ground truths with the same length along dim 0.

    Returns:
        1-element float32 tensor on ``input``'s device holding the mean.

    Raises:
        ValueError: if the batch is empty or the two batch sizes differ.
    """
    n_input, n_target = len(input), len(target)
    # zip() would silently drop the surplus examples of the longer batch.
    if n_input != n_target:
        raise ValueError(
            f"input and target batch sizes differ ({n_input} vs {n_target})"
        )
    # Previously an empty batch crashed with UnboundLocalError on the loop index.
    if n_input == 0:
        raise ValueError("dice_coeff requires a non-empty batch")

    # Allocate on input's own device (not the current CUDA device) so the
    # running sum never mixes devices.
    s = torch.zeros(1, dtype=torch.float32, device=input.device)
    for example, truth in zip(input, target):
        s = s + DiceCoeff().forward(example, truth)

    return s / n_input
|