Fix predict script (mistake from last commit)

Former-commit-id: 9fd386f5929945f3df1868930d392507fee16a55
This commit is contained in:
milesial 2019-11-06 13:13:37 +01:00
parent 8ed1e09b2a
commit f5c2771242
2 changed files with 8 additions and 7 deletions

View file

@ -6,6 +6,7 @@ import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from unet import UNet
from utils import plot_img_and_mask
@ -20,7 +21,6 @@ def predict_img(net,
use_dense_crf=False):
net.eval()
img_height = full_img.size[1]
img_width = full_img.size[0]
img = resize_and_crop(full_img, scale=scale_factor)
img = normalize(img)
@ -32,7 +32,13 @@ def predict_img(net,
with torch.no_grad():
output = net(X)
probs = output.squeeze(0)
if net.n_classes > 1:
probs = F.softmax(output, dim=1)
else:
probs = torch.sigmoid(output)
probs = probs.squeeze(0)
tf = transforms.Compose(
[

View file

@ -35,8 +35,3 @@ class UNet(nn.Module):
x = self.up4(x, x1)
logits = self.outc(x)
return logits
if self.n_classes > 1:
return F.softmax(x, dim=1)
else:
return torch.sigmoid(x)