Merge pull request #318 from Liunaijiaaa/master

Fix wandb error in multi-category segmentation

Former-commit-id: e1a69e7c6ce18edd47271b01e4aabc03b436753d
This commit is contained in:
milesial 2021-11-13 10:28:59 +01:00 committed by GitHub
commit 473a03ce17

View file

@ -130,7 +130,7 @@ def train_net(net,
'images': wandb.Image(images[0].cpu()),
'masks': {
'true': wandb.Image(true_masks[0].float().cpu()),
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()),
},
'step': global_step,
'epoch': epoch,