Removed dense_crf and small fixes
Former-commit-id: de7507ff08510b48e6a0e11da849e0d1c94d3ac8
This commit is contained in:
parent
5f4ce7dba9
commit
012fca4715
12
predict.py
12
predict.py
|
@ -4,22 +4,20 @@ import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils.data_vis import plot_img_and_mask
|
from utils.data_vis import plot_img_and_mask
|
||||||
from utils.dataset import BasicDataset
|
from utils.dataset import BasicDataset
|
||||||
from utils.crf import dense_crf
|
|
||||||
|
|
||||||
|
|
||||||
def predict_img(net,
|
def predict_img(net,
|
||||||
full_img,
|
full_img,
|
||||||
device,
|
device,
|
||||||
scale_factor=1,
|
scale_factor=1,
|
||||||
out_threshold=0.5,
|
out_threshold=0.5):
|
||||||
use_dense_crf=False):
|
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
|
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
|
||||||
|
@ -40,7 +38,7 @@ def predict_img(net,
|
||||||
tf = transforms.Compose(
|
tf = transforms.Compose(
|
||||||
[
|
[
|
||||||
transforms.ToPILImage(),
|
transforms.ToPILImage(),
|
||||||
transforms.Resize(full_img.shape[1]),
|
transforms.Resize(full_img.size[1]),
|
||||||
transforms.ToTensor()
|
transforms.ToTensor()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -48,9 +46,6 @@ def predict_img(net,
|
||||||
probs = tf(probs.cpu())
|
probs = tf(probs.cpu())
|
||||||
full_mask = probs.squeeze().cpu().numpy()
|
full_mask = probs.squeeze().cpu().numpy()
|
||||||
|
|
||||||
if use_dense_crf:
|
|
||||||
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
|
|
||||||
|
|
||||||
return full_mask > out_threshold
|
return full_mask > out_threshold
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,7 +122,6 @@ if __name__ == "__main__":
|
||||||
full_img=img,
|
full_img=img,
|
||||||
scale_factor=args.scale,
|
scale_factor=args.scale,
|
||||||
out_threshold=args.mask_threshold,
|
out_threshold=args.mask_threshold,
|
||||||
use_dense_crf=False,
|
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
if not args.no_save:
|
if not args.no_save:
|
||||||
|
|
Loading…
Reference in a new issue