Removed dense_crf and small fixes
Former-commit-id: de7507ff08510b48e6a0e11da849e0d1c94d3ac8
This commit is contained in:
parent
5f4ce7dba9
commit
012fca4715
12
predict.py
12
predict.py
|
@ -4,22 +4,20 @@ import os
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import torch.nn.functional as F
|
||||
|
||||
from unet import UNet
|
||||
from utils.data_vis import plot_img_and_mask
|
||||
from utils.dataset import BasicDataset
|
||||
from utils.crf import dense_crf
|
||||
|
||||
|
||||
def predict_img(net,
|
||||
full_img,
|
||||
device,
|
||||
scale_factor=1,
|
||||
out_threshold=0.5,
|
||||
use_dense_crf=False):
|
||||
out_threshold=0.5):
|
||||
net.eval()
|
||||
|
||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
|
||||
|
@ -40,7 +38,7 @@ def predict_img(net,
|
|||
tf = transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(full_img.shape[1]),
|
||||
transforms.Resize(full_img.size[1]),
|
||||
transforms.ToTensor()
|
||||
]
|
||||
)
|
||||
|
@ -48,9 +46,6 @@ def predict_img(net,
|
|||
probs = tf(probs.cpu())
|
||||
full_mask = probs.squeeze().cpu().numpy()
|
||||
|
||||
if use_dense_crf:
|
||||
full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
|
||||
|
||||
return full_mask > out_threshold
|
||||
|
||||
|
||||
|
@ -127,7 +122,6 @@ if __name__ == "__main__":
|
|||
full_img=img,
|
||||
scale_factor=args.scale,
|
||||
out_threshold=args.mask_threshold,
|
||||
use_dense_crf=False,
|
||||
device=device)
|
||||
|
||||
if not args.no_save:
|
||||
|
|
28
train.py
28
train.py
|
@ -162,18 +162,18 @@ if __name__ == '__main__':
|
|||
# faster convolutions, but more memory
|
||||
# cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
train_net(net=net,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batchsize,
|
||||
lr=args.lr,
|
||||
device=device,
|
||||
img_scale=args.scale,
|
||||
val_percent=args.val / 100)
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||
logging.info('Saved interrupt')
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0)
|
||||
train_net(net=net,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batchsize,
|
||||
lr=args.lr,
|
||||
device=device,
|
||||
img_scale=args.scale,
|
||||
val_percent=args.val / 100)
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||
logging.info('Saved interrupt')
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0)
|
||||
|
|
Loading…
Reference in a new issue