From 012fca4715114a48be53891574aa1d35a9930f95 Mon Sep 17 00:00:00 2001 From: milesial Date: Sat, 21 Dec 2019 22:04:23 +0100 Subject: [PATCH] Removed dense_crf and small fixes Former-commit-id: de7507ff08510b48e6a0e11da849e0d1c94d3ac8 --- predict.py | 12 +++--------- train.py | 28 ++++++++++++++-------------- utils/dataset.py | 2 +- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/predict.py b/predict.py index 95112b0..fd12ed0 100755 --- a/predict.py +++ b/predict.py @@ -4,22 +4,20 @@ import os import numpy as np import torch +import torch.nn.functional as F from PIL import Image from torchvision import transforms -import torch.nn.functional as F from unet import UNet from utils.data_vis import plot_img_and_mask from utils.dataset import BasicDataset -from utils.crf import dense_crf def predict_img(net, full_img, device, scale_factor=1, - out_threshold=0.5, - use_dense_crf=False): + out_threshold=0.5): net.eval() img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor)) @@ -40,7 +38,7 @@ def predict_img(net, tf = transforms.Compose( [ transforms.ToPILImage(), - transforms.Resize(full_img.shape[1]), + transforms.Resize(full_img.size[1]), transforms.ToTensor() ] ) @@ -48,9 +46,6 @@ def predict_img(net, probs = tf(probs.cpu()) full_mask = probs.squeeze().cpu().numpy() - if use_dense_crf: - full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask) - return full_mask > out_threshold @@ -127,7 +122,6 @@ if __name__ == "__main__": full_img=img, scale_factor=args.scale, out_threshold=args.mask_threshold, - use_dense_crf=False, device=device) if not args.no_save: diff --git a/train.py b/train.py index c23aff2..d52c8b6 100644 --- a/train.py +++ b/train.py @@ -162,18 +162,18 @@ if __name__ == '__main__': # faster convolutions, but more memory # cudnn.benchmark = True -try: - train_net(net=net, - epochs=args.epochs, - batch_size=args.batchsize, - lr=args.lr, - device=device, - img_scale=args.scale, - val_percent=args.val / 100) -except KeyboardInterrupt: - torch.save(net.state_dict(), 'INTERRUPTED.pth') - logging.info('Saved interrupt') try: - sys.exit(0) - except SystemExit: - os._exit(0) + train_net(net=net, + epochs=args.epochs, + batch_size=args.batchsize, + lr=args.lr, + device=device, + img_scale=args.scale, + val_percent=args.val / 100) + except KeyboardInterrupt: + torch.save(net.state_dict(), 'INTERRUPTED.pth') + logging.info('Saved interrupt') + try: + sys.exit(0) + except SystemExit: + os._exit(0) diff --git a/utils/dataset.py b/utils/dataset.py index c290eeb..45d31b7 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -15,7 +15,7 @@ class BasicDataset(Dataset): self.scale = scale assert 0 < scale <= 1, 'Scale must be between 0 and 1' - self.ids = [splitext(file)[0] for file in listdir(imgs_dir) + self.ids = [splitext(file)[0] for file in listdir(imgs_dir) if not file.startswith('.')] logging.info(f'Creating dataset with {len(self.ids)} examples')