Removed unused function and general cleanup

Former-commit-id: c34a455f1722e0b899e9e92c7766b83a9a641980
milesial 2018-04-09 05:15:24 +02:00
parent 0da4ad7753
commit 8008b77af6
7 changed files with 39 additions and 101 deletions

eval.py

@@ -1,11 +1,11 @@
import torch
from myloss import dice_coeff
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from utils import dense_crf, plot_img_mask
from myloss import dice_coeff
from utils import dense_crf
def eval_net(net, dataset, gpu=False):
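
For reference, a minimal sketch of what a Dice-based validation pass of this shape computes. The eval_net signature is taken from the diff; the loop body, the 0.5 threshold, and the local dice helper below are assumptions for illustration, not the repository code.

import torch

def dice(pred, target, eps=1e-4):
    # Soft Dice on flattened masks: 2*|A.B| / (|A| + |B|), with smoothing.
    pred, target = pred.reshape(-1), target.reshape(-1)
    return (2 * torch.dot(pred, target) + eps) / (pred.sum() + target.sum() + eps)

def eval_sketch(net, dataset, gpu=False):
    # Average Dice over (image, true_mask) pairs yielded by `dataset`.
    device = 'cuda' if gpu else 'cpu'
    net.to(device).eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for img, true_mask in dataset:
            img = img.unsqueeze(0).to(device)
            probs = torch.sigmoid(net(img)).squeeze(0).cpu()
            total += dice((probs > 0.5).float(), true_mask.float()).item()
            n += 1
    return total / max(n, 1)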

myloss.py

@@ -1,17 +1,14 @@
#
# myloss.py : implementation of the Dice coeff and the associated loss
#
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
self.inter = torch.dot(input, target) + 0.0001
@@ -46,8 +43,3 @@ def dice_coeff(input, target):
s = s + DiceCoeff().forward(c[0], c[1])
return s / (i + 1)
class DiceLoss(_Loss):
def forward(self, input, target):
return 1 - dice_coeff(F.sigmoid(input), target)
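
The coefficient above is the soft Dice: roughly 2*|A∩B| / (|A| + |B|), smoothed by the 0.0001 added to the intersection, and DiceLoss turns it into a loss as 1 - dice_coeff(F.sigmoid(input), target). A self-contained numeric check of that formula (the helper is illustrative; adding the same constant to the denominator is an assumption, since that line is outside the hunk):

import torch

def soft_dice(pred, target, eps=0.0001):
    # 2 * intersection / (sum(pred) + sum(target)), with smoothing.
    pred, target = pred.reshape(-1), target.reshape(-1)
    inter = torch.dot(pred, target) + eps
    return 2 * inter / (pred.sum() + target.sum() + eps)

pred = torch.tensor([1., 1., 0., 0.])
target = torch.tensor([1., 0., 0., 0.])
print(soft_dice(pred, target))      # ~0.667: one shared pixel, 2 + 1 foreground pixels
print(1 - soft_dice(pred, target))  # the corresponding Dice loss, ~0.333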

predict.py

@@ -1,15 +1,12 @@
import argparse
import numpy
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy
from PIL import Image
import argparse
import os
from utils import *
from unet import UNet
from utils import *
def predict_img(net, full_img, gpu=False):
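
A minimal, self-contained sketch of the kind of single-image inference predict_img performs: PIL image in, thresholded binary mask out. The actual function also resizes/crops the input and can refine the probabilities with dense_crf; those steps, the 0.5 threshold, and the stand-in network below are assumptions for illustration.

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

def predict_sketch(net, full_img, gpu=False, threshold=0.5):
    device = 'cuda' if gpu else 'cpu'
    x = torch.from_numpy(np.array(full_img, dtype=np.float32) / 255.0)
    x = x.permute(2, 0, 1).unsqueeze(0).to(device)   # 1 x C x H x W
    net.to(device).eval()
    with torch.no_grad():
        probs = torch.sigmoid(net(x)).squeeze().cpu().numpy()
    return probs > threshold

# Stand-in 1x1 conv so the sketch runs without the repository's UNet.
toy_net = nn.Conv2d(3, 1, kernel_size=1)
demo_img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
print(predict_sketch(toy_net, demo_img).shape)       # (64, 64)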

submit.py

@@ -1,17 +1,14 @@
# used to predict all test images and encode results in a csv file
import os
from PIL import Image
from predict import *
from utils import encode
from unet import UNet
def submit(net, gpu=False):
dir = 'data/test/'
N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'a') as f:
f.write('img,rle_mask\n')
for index, i in enumerate(os.listdir(dir)):
print('{}/{}'.format(index, N))
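
In outline, submit walks data/test/, runs a prediction per image, and writes one 'img,rle_mask' row per file. A sketch of that loop, assuming a caller-supplied predict_fn that returns the run-length string; opening the file with 'w' instead of the 'a' used above avoids appending to a stale SUBMISSION.csv.

import csv
import os

def submit_sketch(predict_fn, test_dir='data/test/', out_path='SUBMISSION.csv'):
    names = sorted(os.listdir(test_dir))
    with open(out_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['img', 'rle_mask'])
        for index, name in enumerate(names, 1):
            print('{}/{}'.format(index, len(names)))
            writer.writerow([name, predict_fn(os.path.join(test_dir, name))])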

train.py

@@ -1,17 +1,16 @@
import sys
from optparse import OptionParser
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from utils import *
from myloss import DiceLoss
from eval import eval_net
from unet import UNet
from torch.autograd import Variable
from torch import optim
from optparse import OptionParser
import sys
import os
from utils import *
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
@@ -39,15 +38,14 @@ def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
N_train = len(iddataset['train'])
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
optimizer = optim.SGD(net.parameters(),
lr=lr, momentum=0.9, weight_decay=0.0005)
criterion = nn.BCELoss()
for epoch in range(epochs):
print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
# reset the generators
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask)
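
The hunk above fixes the optimizer (SGD with momentum 0.9 and weight decay 5e-4) and the BCELoss criterion, and re-creates the train/val generators at the start of every epoch because they are exhausted after one pass. A compact sketch of that loop; the make_batches callable and the sigmoid placement are assumptions, not the file's exact code.

import torch
import torch.nn as nn
from torch import optim

def train_sketch(net, make_batches, epochs=5, lr=0.1, device='cpu'):
    optimizer = optim.SGD(net.parameters(), lr=lr,
                          momentum=0.9, weight_decay=0.0005)
    criterion = nn.BCELoss()
    net.to(device)
    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        for imgs, true_masks in make_batches():   # reset the generators each epoch
            imgs, true_masks = imgs.to(device), true_masks.to(device)
            probs = torch.sigmoid(net(imgs))      # BCELoss expects probabilities in [0, 1]
            loss = criterion(probs.reshape(-1), true_masks.reshape(-1).float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()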

utils/load.py

@@ -1,13 +1,13 @@
#
# load.py : utils on generators / lists of ids to transform from strings to
# cropped images and masks
import os
import numpy as np
from PIL import Image
from functools import partial
import numpy as np
from PIL import Image
from .utils import resize_and_crop, get_square, normalize
@@ -41,6 +41,7 @@ def get_imgs_and_masks(ids, dir_img, dir_mask):
return zip(imgs_normalized, masks)
def get_full_img_and_mask(id, dir_img, dir_mask):
im = Image.open(dir_img + id + '.jpg')
mask = Image.open(dir_mask + id + '_mask.gif')
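
get_full_img_and_mask shows the naming convention the loaders rely on: '<id>.jpg' in dir_img pairs with '<id>_mask.gif' in dir_mask, and get_imgs_and_masks zips the normalized image crops with their masks as lazy generators. A hypothetical helper that makes the pairing explicit (the directory defaults are assumptions):

import os

def iter_pairs(dir_img='data/train/', dir_mask='data/train_masks/'):
    # Yield (image_path, mask_path) tuples following the id convention above.
    for name in sorted(os.listdir(dir_img)):
        img_id, _ = os.path.splitext(name)
        yield (os.path.join(dir_img, img_id + '.jpg'),
               os.path.join(dir_mask, img_id + '_mask.gif'))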

utils/utils.py

@@ -1,7 +1,7 @@
import PIL
import numpy as np
import random
import numpy as np
def get_square(img, pos):
"""Extract a left or a right square from PILimg shape : (H, W, C))"""
@@ -46,7 +46,6 @@ def split_train_val(dataset, val_percent=0.05):
dataset = list(dataset)
length = len(dataset)
n = int(length * val_percent)
random.seed(42)
random.shuffle(dataset)
return {'train': dataset[:-n], 'val': dataset[-n:]}
@@ -56,58 +55,16 @@ def normalize(x):
def merge_masks(img1, img2, full_w):
w = img1.shape[1]
overlap = int(2 * w - full_w)
h = img1.shape[0]
new = np.zeros((h, full_w), np.float32)
margin = 0
new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
#new[:, w-overlap+1+margin//2:-(w-overlap+margin//2)] = (img1[:, -overlap+margin:] +
# img2[:, :overlap-margin])/2
return new
import matplotlib.pyplot as plt
def encode(mask):
"""mask : HxW"""
plt.imshow(mask.transpose())
plt.show()
flat = mask.transpose().reshape(-1)
enc = []
i = 1
while i <= len(flat):
if(flat[i-1]):
s = i
while(flat[i-1]):
i += 1
e = i-1
enc.append(s)
enc.append(e - s + 1)
i += 1
plt.imshow(decode(enc))
plt.show()
return enc
def decode(list):
mask = np.zeros((1280*1920), np.bool)
for i, e in enumerate(list):
if(i%2 == 0):
mask[e-1:e-2+list[i+1]] = True
mask = mask.reshape(1920, 1280).transpose()
return mask
def rle_encode(mask_image):
pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of
@@ -119,7 +76,3 @@ def rle_encode(mask_image):
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
runs[1::2] = runs[1::2] - runs[:-1:2]
return runs
def full_process(filename):
im = PIL.Image.open(filename)
im = resize_and_crop(im)
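
rle_encode builds flattened run-length pairs: np.where(pixels[1:] != pixels[:-1]) finds the zero-based transition positions, +2 turns each into a one-based run start, and the odd slots are rewritten as run lengths. A small self-contained round-trip check of that scheme; the padding variant and the decode helper below are illustrative, and a tiny mask is used instead of the 1920x1280 shape hard-coded above.

import numpy as np

def rle_encode_demo(mask_image):
    # Same idea as rle_encode above, with explicit zero padding at both ends.
    pixels = np.concatenate([[0], mask_image.flatten(), [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1   # one-based run starts
    runs[1::2] = runs[1::2] - runs[::2]                 # second of each pair -> length
    return runs

def rle_decode_demo(runs, shape):
    flat = np.zeros(int(np.prod(shape)), dtype=np.uint8)
    for start, length in zip(runs[::2], runs[1::2]):
        flat[start - 1:start - 1 + length] = 1
    return flat.reshape(shape)

mask = np.zeros((4, 6), dtype=np.uint8)
mask[1:3, 2:5] = 1
runs = rle_encode_demo(mask)
print(runs)                                             # [ 9  3 15  3]
assert np.array_equal(rle_decode_demo(runs, mask.shape), mask)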