2017-08-16 12:24:29 +00:00
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
|
2019-10-24 19:37:21 +00:00
|
|
|
def plot_img_and_mask(img, mask):
    """Show an input image next to its segmentation mask(s) in one figure row.

    The first panel shows ``img``. It is followed by one panel per mask
    channel, or a single panel when the mask has one channel. The figure is
    displayed with ``plt.show()``.

    Args:
        img: Image data passed unchanged to ``Axes.imshow``.
        mask: Either a 2-D mask, or a 3-D array-like indexed as
            ``mask[:, :, i]``, where the last axis selects the class.
            A 3-D mask whose last axis has size 1 is shown as a single
            2-D mask.
    """
    classes = mask.shape[2] if len(mask.shape) > 2 else 1
    _, ax = plt.subplots(1, classes + 1)

    ax[0].set_title('Input image')
    ax[0].imshow(img)

    if classes > 1:
        for i in range(classes):
            ax[i + 1].set_title(f'Output mask (class {i + 1})')
            ax[i + 1].imshow(mask[:, :, i])
    else:
        ax[1].set_title('Output mask')
        # (H, W, 1) is not one of the array shapes documented for imshow,
        # so drop the trailing channel axis explicitly.
        ax[1].imshow(mask[:, :, 0] if len(mask.shape) > 2 else mask)

    # plt.xticks/plt.yticks only affect pyplot's current Axes (the last
    # panel created by subplots), so clear the ticks on every panel.
    for panel in ax:
        panel.set_xticks([])
        panel.set_yticks([])

    plt.show()
|