Fixed math error in DiceLoss backward pass (#35)
Former-commit-id: 649124b3bd26272bb8442843a546c4e895ec14ff
This commit is contained in:
parent
f2d88f98a7
commit
7e5a8727de
|
@ -20,7 +20,7 @@ class DiceCoeff(Function):
|
||||||
grad_input = grad_target = None
|
grad_input = grad_target = None
|
||||||
|
|
||||||
if self.needs_input_grad[0]:
|
if self.needs_input_grad[0]:
|
||||||
grad_input = grad_output * 2 * (target * self.union + self.inter) \
|
grad_input = grad_output * 2 * (target * self.union - self.inter) \
|
||||||
/ self.union * self.union
|
/ self.union * self.union
|
||||||
if self.needs_input_grad[1]:
|
if self.needs_input_grad[1]:
|
||||||
grad_target = None
|
grad_target = None
|
||||||
|
|
Loading…
Reference in a new issue