# Mirror of https://github.com/Laurent2916/REVA-QCAV.git
# (synced 2024-11-09 15:02:03 +00:00) - Python, 54 lines, 1.4 KiB.
|
|
#
|
|
# myloss.py: implementation of the Dice coefficient and the associated loss
|
|
#
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from torch.nn.modules.loss import _Loss
|
|
from torch.autograd import Function, Variable
|
|
|
|
|
|
class DiceCoeff(Function):
    """Dice coefficient for an individual example.

    Computes ``2 * I / U`` where ``I = <input, target> + 0.0001`` and
    ``U = sum(input) + sum(target) + 0.0001``. The 0.0001 terms keep the
    ratio finite when both tensors are all zeros (the result is then 2.0).

    NOTE(review): ``forward``/``backward`` use the legacy instance style
    (``self`` instead of a static ``ctx``). When ``forward`` is invoked
    directly on an instance (``DiceCoeff().forward(a, b)``), autograd
    differentiates the tensor ops inside ``forward`` and the custom
    ``backward`` below is not used.
    """

    def forward(self, input, target):
        """Return the Dice coefficient of ``input`` and ``target`` as a 0-d float32 tensor.

        Both tensors are flattened before the dot product, so any pair of
        tensors with the same number of elements is accepted (``torch.dot``
        on its own only accepts 1-D tensors).
        """
        self.save_for_backward(input, target)
        input_flat = input.reshape(-1)
        target_flat = target.reshape(-1)
        self.inter = torch.dot(input_flat, target_flat) + 0.0001
        self.union = torch.sum(input_flat) + torch.sum(target_flat) + 0.0001

        t = 2 * self.inter.float() / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Gradient of ``2 * I / U`` with respect to ``input``; none for ``target``."""
        input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            # dI/dinput = target and dU/dinput = 1, hence
            # d(2 I / U)/dinput = 2 * (target * U - I) / U**2.
            # The parentheses around U * U matter: `/ U * U` would cancel out.
            target_like_input = target.reshape(input.shape)
            grad_input = grad_output * 2 * (target_like_input * self.union - self.inter) \
                / (self.union * self.union)

        # The target is a fixed label mask, so no gradient is propagated to it.
        return grad_input, grad_target
|
|
|
|
|
|
def dice_coeff(input, target):
    """Dice coeff for batches.

    Walks the first dimension of ``input`` and ``target`` in lockstep
    (stopping at the shorter one, as ``zip`` does). It scores each pair with
    ``DiceCoeff`` and returns the mean as a shape-(1,) float32 tensor on
    ``input``'s device.

    Raises:
        ValueError: if the batch is empty.
    """
    # Allocate the accumulator on input's own device. `.cuda()` would use the
    # *current* CUDA device, which mismatches inputs living on another GPU.
    s = torch.zeros(1, dtype=torch.float32, device=input.device)
    count = 0
    for sample_input, sample_target in zip(input, target):
        s = s + DiceCoeff().forward(sample_input, sample_target)
        count += 1

    if count == 0:
        # Previously this surfaced as an UnboundLocalError on the loop index.
        raise ValueError("dice_coeff() requires a non-empty batch")
    return s / count
|
|
|
|
|
|
class DiceLoss(_Loss):
    """Dice loss: ``1 - dice_coeff(sigmoid(input), target)``.

    ``input`` is squashed through a sigmoid before scoring, and ``target``
    is passed to ``dice_coeff`` unchanged.
    """

    def forward(self, input, target):
        probabilities = torch.sigmoid(input)
        dice = dice_coeff(probabilities, target)
        return 1 - dice
|