mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Switch to reshape in dice coeff
Former-commit-id: db72295019a2114f4c84940d9aaf1232b2a23352
This commit is contained in:
parent
449965b398
commit
c912d9b726
|
@ -9,7 +9,7 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
|
|||
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
|
||||
|
||||
if input.dim() == 2 or reduce_batch_first:
|
||||
inter = torch.dot(input.view(-1), target.view(-1))
|
||||
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
||||
sets_sum = torch.sum(input) + torch.sum(target)
|
||||
if sets_sum.item() == 0:
|
||||
sets_sum = 2 * inter
|
||||
|
|
Loading…
Reference in a new issue