diff --git a/dice_loss.py b/dice_loss.py index fa2e987..631638f 100644 --- a/dice_loss.py +++ b/dice_loss.py @@ -20,7 +20,7 @@ class DiceCoeff(Function): grad_input = grad_target = None if self.needs_input_grad[0]: - grad_input = grad_output * 2 * (target * self.union + self.inter) \ + grad_input = grad_output * 2 * (target * self.union - self.inter) \ / self.union * self.union if self.needs_input_grad[1]: grad_target = None