diff --git a/utils/dice_score.py b/utils/dice_score.py index f69a286..c07f0d0 100644 --- a/utils/dice_score.py +++ b/utils/dice_score.py @@ -9,7 +9,7 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})') if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.view(-1), target.view(-1)) + inter = torch.dot(input.reshape(-1), target.reshape(-1)) sets_sum = torch.sum(input) + torch.sum(target) if sets_sum.item() == 0: sets_sum = 2 * inter