Fixed misuse of epsilon in Dice loss (issue #24)

Former-commit-id: 1e8093de0398364cc65f3c7722b61a0f560df104
This commit is contained in:
milesial 2018-08-16 19:50:24 +02:00 committed by GitHub
parent f991ffa23e
commit 3fefc25199

View file

@ -6,10 +6,11 @@ class DiceCoeff(Function):
def forward(self, input, target):
self.save_for_backward(input, target)
self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001
eps = 0.0001
self.inter = torch.dot(input.view(-1), target.view(-1))
self.union = torch.sum(input) + torch.sum(target) + eps
t = 2 * self.inter.float() / self.union.float()
t = (2 * self.inter.float() + eps) / self.union.float()
return t
# This function has only a single output, so it gets only one gradient