diff --git a/train.py b/train.py index 146373a..cae9a86 100644 --- a/train.py +++ b/train.py @@ -130,7 +130,7 @@ def train_net(net, 'images': wandb.Image(images[0].cpu()), 'masks': { 'true': wandb.Image(true_masks[0].float().cpu()), - 'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()), + 'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()), }, 'step': global_step, 'epoch': epoch,