mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
Cleaned code, added image to README.md
Former-commit-id: 3acf1ff8dadb74e95786fb6ddcf1a90de63f5079
This commit is contained in:
parent
8b614c3e31
commit
2d427db832
10
README.md
10
README.md
|
@ -1,5 +1,8 @@
|
|||
# Pytorch-UNet
|
||||
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge).
|
||||
![input and output for a random image in the test dataset](https://framapic.org/YqBT4lbLrcfc/kQcSxYDv1Pfk.png)
|
||||
|
||||
|
||||
Customized implementation of the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) in Pytorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge), with only 1 output class, from a high definition image.
|
||||
|
||||
This model scored a [dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 (511 out of 735), which is bad but could be improved with more training, data augmentation, fine tuning, and playing with CRF post-processing.
|
||||
|
||||
|
@ -7,7 +10,7 @@ The model used for the last submission is stored in the `MODEL.pth` file, if you
|
|||
|
||||
## Usage
|
||||
|
||||
### Prediction
|
||||
### Prediction
|
||||
|
||||
You can easily test the output masks on your images via the CLI.
|
||||
To see all options:
|
||||
|
@ -27,3 +30,6 @@ You can specify which model file to use with `--model MODEL.pth`.
|
|||
|
||||
## Warning
|
||||
In order to process the image, it is splitted into two squares (a left on and a right one), and each square is passed into the net. The two square masks are then merged again to produce the final image. As a consequence, the height of the image must be strictly superior than half the width. Make sure the width is even too.
|
||||
|
||||
## Dependencies
|
||||
This package depends on [pydensecrf](https://github.com/lucasb-eyer/pydensecrf), available via `pip install`.
|
||||
|
|
4
eval.py
4
eval.py
|
@ -2,10 +2,10 @@ import torch
|
|||
from myloss import dice_coeff
|
||||
import numpy as np
|
||||
from torch.autograd import Variable
|
||||
from data_vis import plot_img_mask
|
||||
import matplotlib.pyplot as plt
|
||||
import torch.nn.functional as F
|
||||
from crf import dense_crf
|
||||
|
||||
from utils import dense_crf, plot_img_mask
|
||||
|
||||
|
||||
def eval_net(net, dataset, gpu=False):
|
||||
|
|
1
main.py
1
main.py
|
@ -14,7 +14,6 @@ import PIL
|
|||
import os
|
||||
|
||||
#data visualization
|
||||
from data_vis import plot_img_mask
|
||||
from utils import *
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ import argparse
|
|||
import os
|
||||
|
||||
from utils import *
|
||||
from crf import dense_crf
|
||||
|
||||
from unet import UNet
|
||||
|
||||
|
|
5
train.py
5
train.py
|
@ -3,9 +3,7 @@ import torch.backends.cudnn as cudnn
|
|||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
from load import *
|
||||
from data_vis import *
|
||||
from utils import split_train_val, batch
|
||||
from utils import *
|
||||
from myloss import DiceLoss
|
||||
from eval import eval_net
|
||||
from unet import UNet
|
||||
|
@ -14,7 +12,6 @@ from torch import optim
|
|||
from optparse import OptionParser
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
|
||||
|
|
4
utils/__init__.py
Normal file
4
utils/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from .crf import *
|
||||
from .load import *
|
||||
from .utils import *
|
||||
from .data_vis import *
|
|
@ -15,7 +15,6 @@ def dense_crf(img, output_probs):
|
|||
U = np.ascontiguousarray(U)
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
|
||||
d.setUnaryEnergy(U)
|
||||
|
||||
d.addPairwiseGaussian(sxy=20, compat=3)
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
|
||||
from PIL import Image
|
||||
from functools import partial
|
||||
from utils import resize_and_crop, get_square, normalize
|
||||
from .utils import resize_and_crop, get_square, normalize
|
||||
|
||||
|
||||
def get_ids(dir):
|
Loading…
Reference in a new issue