Switch to reshape in dice coeff

Former-commit-id: db72295019a2114f4c84940d9aaf1232b2a23352
This commit is contained in:
milesial 2021-08-31 17:54:02 +02:00 committed by GitHub
parent 449965b398
commit c912d9b726

View file

@ -9,7 +9,7 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})') raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
if input.dim() == 2 or reduce_batch_first: if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.view(-1), target.view(-1)) inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target) sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0: if sets_sum.item() == 0:
sets_sum = 2 * inter sets_sum = 2 * inter