Removed dense_crf and small fixes

Former-commit-id: de7507ff08510b48e6a0e11da849e0d1c94d3ac8
This commit is contained in:
milesial 2019-12-21 22:04:23 +01:00
parent 5f4ce7dba9
commit 012fca4715
3 changed files with 18 additions and 24 deletions

View file

@ -4,22 +4,20 @@ import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
from utils.crf import dense_crf
def predict_img(net,
full_img,
device,
scale_factor=1,
out_threshold=0.5,
use_dense_crf=False):
out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
@ -40,7 +38,7 @@ def predict_img(net,
tf = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize(full_img.shape[1]),
transforms.Resize(full_img.size[1]),
transforms.ToTensor()
]
)
@ -48,9 +46,6 @@ def predict_img(net,
probs = tf(probs.cpu())
full_mask = probs.squeeze().cpu().numpy()
if use_dense_crf:
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
return full_mask > out_threshold
@ -127,7 +122,6 @@ if __name__ == "__main__":
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
use_dense_crf=False,
device=device)
if not args.no_save:

View file

@ -162,18 +162,18 @@ if __name__ == '__main__':
# faster convolutions, but more memory
# cudnn.benchmark = True
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)

View file

@ -15,7 +15,7 @@ class BasicDataset(Dataset):
self.scale = scale
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
logging.info(f'Creating dataset with {len(self.ids)} examples')