diff --git a/predict.py b/predict.py index f8a2894..b26f655 100755 --- a/predict.py +++ b/predict.py @@ -6,6 +6,7 @@ import numpy as np import torch from PIL import Image from torchvision import transforms +import torch.nn.functional as F from unet import UNet from utils import plot_img_and_mask @@ -20,7 +21,6 @@ def predict_img(net, use_dense_crf=False): net.eval() img_height = full_img.size[1] - img_width = full_img.size[0] img = resize_and_crop(full_img, scale=scale_factor) img = normalize(img) @@ -32,7 +32,13 @@ def predict_img(net, with torch.no_grad(): output = net(X) - probs = output.squeeze(0) + + if net.n_classes > 1: + probs = F.softmax(output, dim=1) + else: + probs = torch.sigmoid(output) + + probs = probs.squeeze(0) tf = transforms.Compose( [ diff --git a/unet/unet_model.py b/unet/unet_model.py index 5b23b55..ee1f1bb 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -35,8 +35,3 @@ class UNet(nn.Module): x = self.up4(x, x1) logits = self.outc(x) return logits - - if self.n_classes > 1: - return F.softmax(x, dim=1) - else: - return torch.sigmoid(x)