Fix predict script (mistake from last commit)
Former-commit-id: 9fd386f5929945f3df1868930d392507fee16a55
This commit is contained in:
parent
8ed1e09b2a
commit
f5c2771242
10
predict.py
10
predict.py
|
@ -6,6 +6,7 @@ import numpy as np
|
|||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import torch.nn.functional as F
|
||||
|
||||
from unet import UNet
|
||||
from utils import plot_img_and_mask
|
||||
|
@ -20,7 +21,6 @@ def predict_img(net,
|
|||
use_dense_crf=False):
|
||||
net.eval()
|
||||
img_height = full_img.size[1]
|
||||
img_width = full_img.size[0]
|
||||
|
||||
img = resize_and_crop(full_img, scale=scale_factor)
|
||||
img = normalize(img)
|
||||
|
@ -32,7 +32,13 @@ def predict_img(net,
|
|||
|
||||
with torch.no_grad():
|
||||
output = net(X)
|
||||
probs = output.squeeze(0)
|
||||
|
||||
if net.n_classes > 1:
|
||||
probs = F.softmax(output, dim=1)
|
||||
else:
|
||||
probs = torch.sigmoid(output)
|
||||
|
||||
probs = probs.squeeze(0)
|
||||
|
||||
tf = transforms.Compose(
|
||||
[
|
||||
|
|
|
@ -35,8 +35,3 @@ class UNet(nn.Module):
|
|||
x = self.up4(x, x1)
|
||||
logits = self.outc(x)
|
||||
return logits
|
||||
|
||||
if self.n_classes > 1:
|
||||
return F.softmax(x, dim=1)
|
||||
else:
|
||||
return torch.sigmoid(x)
|
||||
|
|
Loading…
Reference in a new issue