Removed dense_crf and small fixes

Former-commit-id: de7507ff08510b48e6a0e11da849e0d1c94d3ac8
This commit is contained in:
milesial 2019-12-21 22:04:23 +01:00
parent 5f4ce7dba9
commit 012fca4715
3 changed files with 18 additions and 24 deletions

View file

@ -4,22 +4,20 @@ import os
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
import torch.nn.functional as F
from unet import UNet from unet import UNet
from utils.data_vis import plot_img_and_mask from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset from utils.dataset import BasicDataset
from utils.crf import dense_crf
def predict_img(net, def predict_img(net,
full_img, full_img,
device, device,
scale_factor=1, scale_factor=1,
out_threshold=0.5, out_threshold=0.5):
use_dense_crf=False):
net.eval() net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor)) img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
@ -40,7 +38,7 @@ def predict_img(net,
tf = transforms.Compose( tf = transforms.Compose(
[ [
transforms.ToPILImage(), transforms.ToPILImage(),
transforms.Resize(full_img.shape[1]), transforms.Resize(full_img.size[1]),
transforms.ToTensor() transforms.ToTensor()
] ]
) )
@ -48,9 +46,6 @@ def predict_img(net,
probs = tf(probs.cpu()) probs = tf(probs.cpu())
full_mask = probs.squeeze().cpu().numpy() full_mask = probs.squeeze().cpu().numpy()
if use_dense_crf:
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
return full_mask > out_threshold return full_mask > out_threshold
@ -127,7 +122,6 @@ if __name__ == "__main__":
full_img=img, full_img=img,
scale_factor=args.scale, scale_factor=args.scale,
out_threshold=args.mask_threshold, out_threshold=args.mask_threshold,
use_dense_crf=False,
device=device) device=device)
if not args.no_save: if not args.no_save:

View file

@ -162,18 +162,18 @@ if __name__ == '__main__':
# faster convolutions, but more memory # faster convolutions, but more memory
# cudnn.benchmark = True # cudnn.benchmark = True
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try: try:
sys.exit(0) train_net(net=net,
except SystemExit: epochs=args.epochs,
os._exit(0) batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)