diff --git a/dice_loss.py b/dice_loss.py index 631638f..71edf6a 100644 --- a/dice_loss.py +++ b/dice_loss.py @@ -21,7 +21,7 @@ class DiceCoeff(Function): if self.needs_input_grad[0]: grad_input = grad_output * 2 * (target * self.union - self.inter) \ - / self.union * self.union + / (self.union * self.union) if self.needs_input_grad[1]: grad_target = None