Fixed misuse of epsilon in Dice loss (issue #24)

Former-commit-id: 1e8093de0398364cc65f3c7722b61a0f560df104
This commit is contained in:
milesial 2018-08-16 19:50:24 +02:00 committed by GitHub
parent f991ffa23e
commit 3fefc25199

View file

@ -6,10 +6,11 @@ class DiceCoeff(Function):
def forward(self, input, target): def forward(self, input, target):
self.save_for_backward(input, target) self.save_for_backward(input, target)
self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001 eps = 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001 self.inter = torch.dot(input.view(-1), target.view(-1))
self.union = torch.sum(input) + torch.sum(target) + eps
t = 2 * self.inter.float() / self.union.float() t = (2 * self.inter.float() + eps) / self.union.float()
return t return t
# This function has only a single output, so it gets only one gradient # This function has only a single output, so it gets only one gradient