diff --git a/dice_loss.py b/dice_loss.py index 29a287d..fa2e987 100644 --- a/dice_loss.py +++ b/dice_loss.py @@ -6,10 +6,11 @@ class DiceCoeff(Function): def forward(self, input, target): self.save_for_backward(input, target) - self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001 - self.union = torch.sum(input) + torch.sum(target) + 0.0001 + eps = 0.0001 + self.inter = torch.dot(input.view(-1), target.view(-1)) + self.union = torch.sum(input) + torch.sum(target) + eps - t = 2 * self.inter.float() / self.union.float() + t = (2 * self.inter.float() + eps) / self.union.float() return t # This function has only a single output, so it gets only one gradient