Now using utils.data.Dataset

Former-commit-id: c75d9c075e18add5cd8683faf827937393bf2c94
Author: milesial
Date:   2019-11-23 14:22:42 +01:00
Parent: f5c2771242
Commit: 35f955cbf8
6 changed files with 105 additions and 145 deletions

eval.py (28 changed lines)

@@ -5,27 +5,25 @@ from tqdm import tqdm
 from dice_loss import dice_coeff
 
 
-def eval_net(net, dataset, device, n_val):
+def eval_net(net, loader, device, n_val):
     """Evaluation without the densecrf with the dice coefficient"""
     net.eval()
     tot = 0
-    for i, b in tqdm(enumerate(dataset), total=n_val, desc='Validation round', unit='img'):
-        img = b[0]
-        true_mask = b[1]
+    for i, b in tqdm(enumerate(loader), desc='Validation round', unit='img'):
+        imgs = b['image']
+        true_masks = b['mask']
 
-        img = torch.from_numpy(img).unsqueeze(0)
-        true_mask = torch.from_numpy(true_mask).unsqueeze(0)
+        imgs = imgs.to(device=device, dtype=torch.float32)
+        true_masks = true_masks.to(device=device, dtype=torch.float32)
 
-        img = img.to(device=device)
-        true_mask = true_mask.to(device=device)
+        mask_pred = net(imgs)
 
-        mask_pred = net(img).squeeze(dim=0)
-        mask_pred = (mask_pred > 0.5).float()
-        if net.n_classes > 1:
-            tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
-        else:
-            tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+        for true_mask in true_masks:
+            mask_pred = (mask_pred > 0.5).float()
+            if net.n_classes > 1:
+                tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
+            else:
+                tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
     return tot / n_val
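Note: eval_net now consumes batches from a torch.utils.data.DataLoader instead of a raw (img, mask) generator. A minimal, hedged sketch of calling the new API (the UNet constructor arguments are an assumption, not part of this commit):

import torch
from torch.utils.data import DataLoader
from eval import eval_net
from unet import UNet
from utils.dataset import BasicDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=3, n_classes=1).to(device=device)  # assumed constructor signature

val_set = BasicDataset('data/imgs/', 'data/masks/', scale=0.5)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

# n_val is only used to average the summed metric over the validation set
score = eval_net(net, val_loader, device, n_val=len(val_set))
print(f'Validation Dice: {score:.4f}')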

predict.py

@@ -10,8 +10,7 @@ import torch.nn.functional as F
 from unet import UNet
 from utils import plot_img_and_mask
-from utils import resize_and_crop, normalize, hwc_to_chw, dense_crf
+from utils.dataset import BasicDataset
 
 
 def predict_img(net,
                 full_img,
@@ -20,18 +19,15 @@ def predict_img(net,
                 out_threshold=0.5,
                 use_dense_crf=False):
     net.eval()
-    img_height = full_img.size[1]
-    img = resize_and_crop(full_img, scale=scale_factor)
-    img = normalize(img)
-    img = hwc_to_chw(img)
+    ds = BasicDataset('', '', scale=scale_factor)
+    img = ds.preprocess(full_img)
 
-    X = torch.from_numpy(img).unsqueeze(0)
-    X = X.to(device=device)
+    img = img.unsqueeze(0)
+    img = img.to(device=device, dtype=torch.float32)
 
     with torch.no_grad():
-        output = net(X)
+        output = net(img)
 
     if net.n_classes > 1:
         probs = F.softmax(output, dim=1)
@@ -43,13 +39,12 @@ def predict_img(net,
         tf = transforms.Compose(
             [
                 transforms.ToPILImage(),
-                transforms.Resize(img_height),
+                transforms.Resize(full_img.shape[1]),
                 transforms.ToTensor()
             ]
         )
 
         probs = tf(probs.cpu())
         full_mask = probs.squeeze().cpu().numpy()
 
     if use_dense_crf:
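Note: predict_img no longer does its own resize/normalize/transpose; it borrows BasicDataset.preprocess via a throwaway BasicDataset('', '', scale=scale_factor). (Worth noting: that instantiation reaches listdir('') in __init__, which raises FileNotFoundError on most platforms, so preprocess may need to become a class/static method.) A hedged usage sketch, with the checkpoint path and exact keyword order assumed:

import torch
from PIL import Image
from unet import UNet
from predict import predict_img

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=3, n_classes=1).to(device=device)  # assumed constructor signature
net.load_state_dict(torch.load('checkpoints/CP_epoch5.pth', map_location=device))  # hypothetical checkpoint

img = Image.open('data/imgs/example.jpg')  # hypothetical input image
mask = predict_img(net=net, full_img=img, device=device,
                   scale_factor=0.5, out_threshold=0.5)  # assumed signature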

train.py

@@ -13,6 +13,9 @@ from eval import eval_net
 from unet import UNet
-from utils import get_ids, split_train_val, get_imgs_and_masks, batch
+from utils.dataset import BasicDataset
+from torch.utils.data import DataLoader, random_split
 
 dir_img = 'data/imgs/'
 dir_mask = 'data/masks/'
 dir_checkpoint = 'checkpoints/'
@@ -26,23 +29,25 @@ def train_net(net,
               val_percent=0.15,
               save_cp=True,
               img_scale=0.5):
-    ids = get_ids(dir_img)
-    iddataset = split_train_val(ids, val_percent)
+    dataset = BasicDataset(dir_img, dir_mask, img_scale)
+    n_val = int(len(dataset) * val_percent)
+    n_train = len(dataset) - n_val
+    train, val = random_split(dataset, [n_train, n_val])
+    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4)
 
     logging.info(f'''Starting training:
         Epochs: {epochs}
         Batch size: {batch_size}
         Learning rate: {lr}
-        Training size: {len(iddataset["train"])}
-        Validation size: {len(iddataset["val"])}
+        Training size: {n_train}
+        Validation size: {n_val}
         Checkpoints: {save_cp}
         Device: {device.type}
         Images scaling: {img_scale}
     ''')
 
-    n_train = len(iddataset['train'])
-    n_val = len(iddataset['val'])
     optimizer = optim.Adam(net.parameters(), lr=lr)
     if net.n_classes > 1:
         criterion = nn.CrossEntropyLoss()
@@ -52,21 +57,23 @@ def train_net(net,
     for epoch in range(epochs):
         net.train()
 
-        # reset the generators
-        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
-        val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
-
         epoch_loss = 0
         with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
-            for i, b in enumerate(batch(train, batch_size)):
-                imgs = np.array([i[0] for i in b]).astype(np.float32)
-                true_masks = np.array([i[1] for i in b])
+            for batch in train_loader:
+                imgs = batch['image']
+                true_masks = batch['mask']
 
-                imgs = torch.from_numpy(imgs)
-                true_masks = torch.from_numpy(true_masks)
+                assert imgs.shape[1] == net.n_channels, \
+                    f'Network has been defined with {net.n_channels} input channels, ' \
+                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
+                    'the images are loaded correctly.'
+                assert true_masks.shape[1] == net.n_classes, \
+                    f'Network has been defined with {net.n_classes} output classes, ' \
+                    f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
+                    'the masks are loaded correctly.'
 
-                imgs = imgs.to(device=device)
-                true_masks = true_masks.to(device=device)
+                imgs = imgs.to(device=device, dtype=torch.float32)
+                true_masks = true_masks.to(device=device, dtype=torch.float32)
 
                 masks_pred = net(imgs)
                 loss = criterion(masks_pred, true_masks)
@@ -90,7 +97,7 @@ def train_net(net,
                        dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
             logging.info(f'Checkpoint {epoch + 1} saved !')
 
-        val_score = eval_net(net, val, device, n_val)
+        val_score = eval_net(net, val_loader, device, n_val)
         if net.n_classes > 1:
             logging.info('Validation cross entropy: {}'.format(val_score))
@@ -117,18 +124,9 @@ def get_args():
     return parser.parse_args()
 
 
-def pretrain_checks():
-    imgs = [f for f in os.listdir(dir_img) if not f.startswith('.')]
-    masks = [f for f in os.listdir(dir_mask) if not f.startswith('.')]
-    if len(imgs) != len(masks):
-        logging.warning(f'The number of images and masks do not match ! '
-                        f'{len(imgs)} images and {len(masks)} masks detected in the data folder.')
-
-
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
     args = get_args()
-    pretrain_checks()
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     logging.info(f'Using device {device}')
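Note: the split-and-load logic above is the standard PyTorch pattern replacing the hand-rolled split_train_val/batch generators. A self-contained sketch of just that pattern, using a TensorDataset as a stand-in for BasicDataset:

import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

# Stand-in data: 100 fake RGB images with single-channel masks.
dataset = TensorDataset(torch.rand(100, 3, 64, 64), torch.rand(100, 1, 64, 64))

val_percent = 0.15
n_val = int(len(dataset) * val_percent)  # 15
n_train = len(dataset) - n_val           # 85
train, val = random_split(dataset, [n_train, n_val])

# Only the training split is shuffled; the commit additionally uses num_workers=4.
train_loader = DataLoader(train, batch_size=4, shuffle=True)
val_loader = DataLoader(val, batch_size=4, shuffle=False)

imgs, masks = next(iter(train_loader))
print(imgs.shape, masks.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 1, 64, 64])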

utils/dataset.py (new file, 60 lines)

@@ -0,0 +1,60 @@
+from os.path import splitext
+from os import listdir
+import numpy as np
+from glob import glob
+import torch
+from torch.utils.data import Dataset
+import logging
+from PIL import Image
+
+
+class BasicDataset(Dataset):
+    def __init__(self, imgs_dir, masks_dir, scale=1):
+        self.imgs_dir = imgs_dir
+        self.masks_dir = masks_dir
+        self.scale = scale
+        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
+
+        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
+                    if not file.startswith('.')]
+        logging.info(f'Creating dataset with {len(self.ids)} examples')
+
+    def __len__(self):
+        return len(self.ids)
+
+    def preprocess(self, pil_img):
+        w, h = pil_img.size
+        newW, newH = int(self.scale * w), int(self.scale * h)
+        pil_img = pil_img.resize((newW, newH))
+
+        img_nd = np.array(pil_img)
+        if len(img_nd.shape) == 2:
+            img_nd = np.expand_dims(img_nd, axis=2)
+
+        # HWC to CHW
+        img_trans = img_nd.transpose((2, 0, 1))
+        if img_trans.max() > 1:
+            img_trans = img_trans / 255
+
+        return img_trans
+
+    def __getitem__(self, i):
+        idx = self.ids[i]
+        mask_file = glob(self.masks_dir + idx + '*')
+        img_file = glob(self.imgs_dir + idx + '*')
+
+        assert len(mask_file) == 1, \
+            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
+        assert len(img_file) == 1, \
+            f'Either no image or multiple images found for the ID {idx}: {img_file}'
+        mask = Image.open(mask_file[0])
+        img = Image.open(img_file[0])
+
+        assert img.size == mask.size, \
+            f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
+
+        img = self.preprocess(img)
+        mask = self.preprocess(mask)
+
+        return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
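Note: BasicDataset pairs files by ID prefix, and because paths are joined by plain string concatenation before glob, imgs_dir and masks_dir must end with a path separator. A hedged usage sketch under the repo's data layout (data/imgs/<id>.jpg with a matching data/masks/<id>* mask):

from torch.utils.data import DataLoader
from utils.dataset import BasicDataset

dataset = BasicDataset('data/imgs/', 'data/masks/', scale=0.5)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch['image'].shape)  # e.g. torch.Size([2, 3, H/2, W/2]) after 0.5 scaling
print(batch['image'].dtype)  # typically float64, from the /255 division; train.py casts to float32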

utils/load.py (deleted)

@@ -1,40 +0,0 @@
-""" Utils on generators / lists of ids to transform from strings to cropped images and masks """
-
-import os
-
-import numpy as np
-from PIL import Image
-
-from .utils import resize_and_crop, normalize, hwc_to_chw
-
-
-def get_ids(dir):
-    """Returns a list of the ids in the directory"""
-    return (os.path.splitext(f)[0] for f in os.listdir(dir) if not f.startswith('.'))
-
-
-def to_cropped_imgs(ids, dir, suffix, scale):
-    """From a list of tuples, returns the correct cropped img"""
-    for id in ids:
-        im = resize_and_crop(Image.open(dir + id + suffix), scale=scale)
-        yield im
-
-
-def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
-    """Return all the couples (img, mask)"""
-
-    imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
-
-    # need to transform from HWC to CHW
-    imgs_switched = map(hwc_to_chw, imgs)
-    imgs_normalized = map(normalize, imgs_switched)
-
-    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
-    masks_switched = map(hwc_to_chw, masks)
-
-    return zip(imgs_normalized, masks_switched)
-
-
-def get_full_img_and_mask(id, dir_img, dir_mask):
-    im = Image.open(dir_img + id + '.jpg')
-    mask = Image.open(dir_mask + id + '_mask.gif')
-    return np.array(im), np.array(mask)
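Note: this deleted pipeline was single-pass, which is why the old training loop rebuilt it every epoch ('# reset the generators'); a torch.utils.data.Dataset is indexable and can be iterated any number of times. A tiny illustration of the exhaustion problem:

# zip/map chains yield each item once, then stay empty forever.
pairs = zip(map(str.upper, ['a', 'b']), ['x', 'y'])
print(list(pairs))  # [('A', 'x'), ('B', 'y')]
print(list(pairs))  # [] -- exhausted; the deleted code had to recreate these per epoch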

utils/utils.py

@@ -1,56 +1,5 @@
-import random
-
-import numpy as np
-
-
-def hwc_to_chw(img):
-    return np.transpose(img, axes=[2, 0, 1])
-
-
-def resize_and_crop(pilimg, scale=0.5, final_height=None):
-    w = pilimg.size[0]
-    h = pilimg.size[1]
-    newW = int(w * scale)
-    newH = int(h * scale)
-
-    if not final_height:
-        diff = 0
-    else:
-        diff = newH - final_height
-
-    img = pilimg.resize((newW, newH))
-    img = img.crop((0, diff // 2, newW, newH - diff // 2))
-    ar = np.array(img, dtype=np.float32)
-
-    if len(ar.shape) == 2:
-        # for greyscale images, add a new axis
-        ar = np.expand_dims(ar, axis=2)
-    return ar
-
-
-def batch(iterable, batch_size):
-    """Yields lists by batch"""
-    b = []
-    for i, t in enumerate(iterable):
-        b.append(t)
-        if (i + 1) % batch_size == 0:
-            yield b
-            b = []
-
-    if len(b) > 0:
-        yield b
-
-
-def split_train_val(dataset, val_percent=0.05):
-    dataset = list(dataset)
-    length = len(dataset)
-    n = int(length * val_percent)
-    random.shuffle(dataset)
-    return {'train': dataset[:-n], 'val': dataset[-n:]}
-
-
-def normalize(x):
-    return x / 255
-
 
 # credits to https://stackoverflow.com/users/6076729/manuel-lagunas
 def rle_encode(mask_image):
     pixels = mask_image.flatten()