Added CLI for predict, cleaned up code, updated README
Former-commit-id: 77555ccc0925a8fba796ce7e42843d95b6e9dce0
This commit is contained in:
parent
e1bf150da3
commit
7ea54febec
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -3,4 +3,5 @@ data/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
checkpoints/
|
checkpoints/
|
||||||
*.pth
|
*.pth
|
||||||
|
*.jpg
|
||||||
|
SUBMISSION*
|
||||||
|
|
14
README.md
14
README.md
|
@ -5,7 +5,21 @@ This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rens
|
||||||
|
|
||||||
The model used for the last submission is stored in the `MODEL.pth` file, if you wish to play with it. The data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
|
The model used for the last submission is stored in the `MODEL.pth` file, if you wish to play with it. The data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
### Prediction
|
||||||
|
|
||||||
|
You can easily test the output masks on your images via the CLI.
|
||||||
|
To see all options:
|
||||||
|
`python predict.py -h`
|
||||||
|
|
||||||
|
To predict a single image and save it:
|
||||||
|
`python predict.py -i image.jpg -o output.jpg`
|
||||||
|
|
||||||
|
To predict multiple images and show them without saving them:
|
||||||
|
`python predict.py -i image1.jpg image2.jpg --viz --no-save`
|
||||||
|
|
||||||
|
You can use the cpu-only version with `--cpu`.
|
||||||
|
You can specify which model file to use with `--model MODEL.pth`.
|
||||||
|
|
||||||
## Note
|
## Note
|
||||||
The code and the overall project architecture is a big mess for now, as I left it abandoned when the challenge finished. I will clean it Soon<sup>TM</sup>.
|
The code and the overall project architecture is a big mess for now, as I left it abandoned when the challenge finished. I will clean it Soon<sup>TM</sup>.
|
||||||
|
|
2
main.py
2
main.py
|
@ -1,5 +1,5 @@
|
||||||
#models
|
#models
|
||||||
from unet_model import UNet
|
from unet import UNet
|
||||||
from myloss import *
|
from myloss import *
|
||||||
import torch
|
import torch
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
|
|
83
predict.py
83
predict.py
|
@ -1,12 +1,16 @@
|
||||||
import torch
|
import torch
|
||||||
from utils import *
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from PIL import Image
|
|
||||||
from unet_model import UNet
|
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy
|
||||||
|
from PIL import Image
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from utils import *
|
||||||
from crf import dense_crf
|
from crf import dense_crf
|
||||||
|
|
||||||
|
from unet import UNet
|
||||||
|
|
||||||
def predict_img(net, full_img, gpu=False):
|
def predict_img(net, full_img, gpu=False):
|
||||||
img = resize_and_crop(full_img)
|
img = resize_and_crop(full_img)
|
||||||
|
@ -39,3 +43,76 @@ def predict_img(net, full_img, gpu=False):
|
||||||
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
|
yy = dense_crf(np.array(full_img).astype(np.uint8), y)
|
||||||
|
|
||||||
return yy > 0.5
|
return yy > 0.5
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: load a trained UNet, run `predict_img` on each
    # input image, optionally display the result and/or save the mask.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', default='MODEL.pth',
                        metavar='FILE',
                        help="Specify the file in which is stored the model"
                             " (default : 'MODEL.pth')")
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='filenames of input images', required=True)
    # metavar was 'INPUT' (copy-paste slip); the usage text now reads OUTPUT.
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='filenames of ouput images')
    parser.add_argument('--cpu', '-c', action='store_true',
                        help="Do not use the cuda version of the net",
                        default=False)
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed",
                        default=False)
    # Must be 'store_true': with 'store_false' and default=False the value was
    # always False, so passing --no-save could never disable saving.
    parser.add_argument('--no-save', '-n', action='store_true',
                        help="Do not save the output masks",
                        default=False)

    args = parser.parse_args()
    print("Using model file : {}".format(args.model))

    net = UNet(3, 1)
    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
    else:
        net.cpu()
        print("Using CPU version of the net, this may be very slow")

    in_files = args.input
    if not args.output:
        # No explicit outputs: derive "<name>_OUT<ext>" next to each input.
        out_files = []
        for f in in_files:
            root, ext = os.path.splitext(f)
            out_files.append("{}_OUT{}".format(root, ext))
    elif len(in_files) != len(args.output):
        print("Error : Input files and output files are not of the same length")
        # Exit with a non-zero status so callers can detect the failure.
        raise SystemExit(1)
    else:
        out_files = args.output

    print("Loading model ...")
    if args.cpu:
        # Keep every storage on the CPU so a checkpoint saved from a GPU can
        # still be loaded on a machine without CUDA.
        state_dict = torch.load(args.model,
                                map_location=lambda storage, loc: storage)
    else:
        state_dict = torch.load(args.model)
    net.load_state_dict(state_dict)
    print("Model loaded !")

    for fn, out_fn in zip(in_files, out_files):
        print("\nPredicting image {} ...".format(fn))
        img = Image.open(fn)
        out = predict_img(net, img, not args.cpu)

        if args.viz:
            print("Vizualising results for image {}, close to continue ..."
                  .format(fn))

            fig = plt.figure()
            a = fig.add_subplot(1, 2, 1)
            a.set_title('Input image')
            plt.imshow(img)

            b = fig.add_subplot(1, 2, 2)
            b.set_title('Output mask')
            plt.imshow(out)

            # Blocks until the figure window is closed.
            plt.show()

        if not args.no_save:
            # `out` is the thresholded mask returned by predict_img; scale it
            # to 0/255 so it can be written as an 8-bit image.
            result = Image.fromarray((out * 255).astype(numpy.uint8))
            result.save(out_fn)
            print("Mask saved to {}".format(out_fn))
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from predict import *
|
from predict import *
|
||||||
from utils import encode
|
from utils import encode
|
||||||
from unet_model import UNet
|
from unet import UNet
|
||||||
|
|
||||||
def submit(net, gpu=False):
|
def submit(net, gpu=False):
|
||||||
dir = 'data/test/'
|
dir = 'data/test/'
|
||||||
|
|
3
train.py
3
train.py
|
@ -8,12 +8,13 @@ from data_vis import *
|
||||||
from utils import split_train_val, batch
|
from utils import split_train_val, batch
|
||||||
from myloss import DiceLoss
|
from myloss import DiceLoss
|
||||||
from eval import eval_net
|
from eval import eval_net
|
||||||
from unet_model import UNet
|
from unet import UNet
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from optparse import OptionParser
|
from optparse import OptionParser
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
||||||
|
|
1
unet/__init__.py
Normal file
1
unet/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .unet_model import UNet
|
|
@ -1,8 +1,12 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
# full assembly of the sub-parts to form the complete net
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from unet_parts import *
|
# python 3 confusing imports :(
|
||||||
|
from .unet_parts import *
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
|
@ -1,3 +1,5 @@
|
||||||
|
#!/usr/bin/python
|
||||||
|
|
||||||
# sub-parts of the U-Net model
|
# sub-parts of the U-Net model
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -6,6 +8,7 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class double_conv(nn.Module):
|
class double_conv(nn.Module):
|
||||||
|
'''(conv => BN => ReLU) * 2'''
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch):
|
||||||
super(double_conv, self).__init__()
|
super(double_conv, self).__init__()
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
|
@ -46,10 +49,16 @@ class down(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class up(nn.Module):
|
class up(nn.Module):
|
||||||
def __init__(self, in_ch, out_ch):
|
def __init__(self, in_ch, out_ch, bilinear=True):
|
||||||
super(up, self).__init__()
|
super(up, self).__init__()
|
||||||
|
|
||||||
|
# would be a nice idea if the upsampling could be learned too,
|
||||||
|
# but my machine do not have enough memory to handle all those weights
|
||||||
|
if bilinear:
|
||||||
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
|
||||||
# self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
|
else:
|
||||||
|
self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
|
||||||
|
|
||||||
self.conv = double_conv(in_ch, out_ch)
|
self.conv = double_conv(in_ch, out_ch)
|
||||||
|
|
||||||
def forward(self, x1, x2):
|
def forward(self, x1, x2):
|
4
utils.py
4
utils.py
|
@ -119,3 +119,7 @@ def rle_encode(mask_image):
|
||||||
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
|
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
|
||||||
runs[1::2] = runs[1::2] - runs[:-1:2]
|
runs[1::2] = runs[1::2] - runs[:-1:2]
|
||||||
return runs
|
return runs
|
||||||
|
|
||||||
|
def full_process(filename):
|
||||||
|
im = PIL.Image.open(filename)
|
||||||
|
im = resize_and_crop(im)
|
||||||
|
|
Loading…
Reference in a new issue