Fix predict script (mistake from last commit)
Former-commit-id: 9fd386f5929945f3df1868930d392507fee16a55
This commit is contained in:
parent
8ed1e09b2a
commit
f5c2771242
10
predict.py
10
predict.py
|
@ -6,6 +6,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils import plot_img_and_mask
|
from utils import plot_img_and_mask
|
||||||
|
@ -20,7 +21,6 @@ def predict_img(net,
|
||||||
use_dense_crf=False):
|
use_dense_crf=False):
|
||||||
net.eval()
|
net.eval()
|
||||||
img_height = full_img.size[1]
|
img_height = full_img.size[1]
|
||||||
img_width = full_img.size[0]
|
|
||||||
|
|
||||||
img = resize_and_crop(full_img, scale=scale_factor)
|
img = resize_and_crop(full_img, scale=scale_factor)
|
||||||
img = normalize(img)
|
img = normalize(img)
|
||||||
|
@ -32,7 +32,13 @@ def predict_img(net,
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = net(X)
|
output = net(X)
|
||||||
probs = output.squeeze(0)
|
|
||||||
|
if net.n_classes > 1:
|
||||||
|
probs = F.softmax(output, dim=1)
|
||||||
|
else:
|
||||||
|
probs = torch.sigmoid(output)
|
||||||
|
|
||||||
|
probs = probs.squeeze(0)
|
||||||
|
|
||||||
tf = transforms.Compose(
|
tf = transforms.Compose(
|
||||||
[
|
[
|
||||||
|
|
|
@ -35,8 +35,3 @@ class UNet(nn.Module):
|
||||||
x = self.up4(x, x1)
|
x = self.up4(x, x1)
|
||||||
logits = self.outc(x)
|
logits = self.outc(x)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
if self.n_classes > 1:
|
|
||||||
return F.softmax(x, dim=1)
|
|
||||||
else:
|
|
||||||
return torch.sigmoid(x)
|
|
||||||
|
|
Loading…
Reference in a new issue