diff --git a/utils/utils.py b/utils/utils.py index 1d48738..c1eb578 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -9,7 +9,7 @@ def plot_img_and_mask(img, mask): if classes > 1: for i in range(classes): ax[i + 1].set_title(f'Output mask (class {i + 1})') - ax[i + 1].imshow(mask[:, :, i]) + ax[i + 1].imshow(mask[1, :, :]) else: ax[1].set_title(f'Output mask') ax[1].imshow(mask)