mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Fixed misuse of epsilon in Dice loss (issue #24)
Former-commit-id: 1e8093de0398364cc65f3c7722b61a0f560df104
This commit is contained in:
parent
f991ffa23e
commit
3fefc25199
|
@ -6,10 +6,11 @@ class DiceCoeff(Function):
|
|||
|
||||
def forward(self, input, target):
|
||||
self.save_for_backward(input, target)
|
||||
self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
|
||||
self.union = torch.sum(input) + torch.sum(target) + 0.0001
|
||||
eps = 0.0001
|
||||
self.inter = torch.dot(input.view(-1), target.view(-1))
|
||||
self.union = torch.sum(input) + torch.sum(target) + eps
|
||||
|
||||
t = 2 * self.inter.float() / self.union.float()
|
||||
t = (2 * self.inter.float() + eps) / self.union.float()
|
||||
return t
|
||||
|
||||
# This function has only a single output, so it gets only one gradient
|
||||
|
|
Loading…
Reference in a new issue