mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Switch to reshape in dice coeff
Former-commit-id: db72295019a2114f4c84940d9aaf1232b2a23352
This commit is contained in:
parent
449965b398
commit
c912d9b726
|
@ -9,7 +9,7 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
|
||||||
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
|
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
|
||||||
|
|
||||||
if input.dim() == 2 or reduce_batch_first:
|
if input.dim() == 2 or reduce_batch_first:
|
||||||
inter = torch.dot(input.view(-1), target.view(-1))
|
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
||||||
sets_sum = torch.sum(input) + torch.sum(target)
|
sets_sum = torch.sum(input) + torch.sum(target)
|
||||||
if sets_sum.item() == 0:
|
if sets_sum.item() == 0:
|
||||||
sets_sum = 2 * inter
|
sets_sum = 2 * inter
|
||||||
|
|
Loading…
Reference in a new issue