diff --git a/eval.py b/eval.py index e48d522..e261532 100644 --- a/eval.py +++ b/eval.py @@ -5,27 +5,25 @@ from tqdm import tqdm from dice_loss import dice_coeff -def eval_net(net, dataset, device, n_val): +def eval_net(net, loader, device, n_val): """Evaluation without the densecrf with the dice coefficient""" net.eval() tot = 0 - for i, b in tqdm(enumerate(dataset), total=n_val, desc='Validation round', unit='img'): - img = b[0] - true_mask = b[1] + for i, b in tqdm(enumerate(loader), desc='Validation round', unit='img'): + imgs = b['image'] + true_masks = b['mask'] - img = torch.from_numpy(img).unsqueeze(0) - true_mask = torch.from_numpy(true_mask).unsqueeze(0) + imgs = imgs.to(device=device, dtype=torch.float32) + true_masks = true_masks.to(device=device, dtype=torch.float32) - img = img.to(device=device) - true_mask = true_mask.to(device=device) + mask_pred = net(imgs) - mask_pred = net(img).squeeze(dim=0) - - mask_pred = (mask_pred > 0.5).float() - if net.n_classes > 1: - tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() - else: - tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item() + for true_mask in true_masks: + mask_pred = (mask_pred > 0.5).float() + if net.n_classes > 1: + tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() + else: + tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item() return tot / n_val diff --git a/predict.py b/predict.py index b26f655..dcb879e 100755 --- a/predict.py +++ b/predict.py @@ -10,8 +10,7 @@ import torch.nn.functional as F from unet import UNet from utils import plot_img_and_mask -from utils import resize_and_crop, normalize, hwc_to_chw, dense_crf - +from utils.dataset import BasicDataset def predict_img(net, full_img, @@ -20,18 +19,15 @@ def predict_img(net, out_threshold=0.5, use_dense_crf=False): net.eval() - img_height = full_img.size[1] - img = resize_and_crop(full_img, scale=scale_factor) - img = normalize(img) - img = hwc_to_chw(img) + ds = BasicDataset('', '', scale=scale_factor) + img = ds.preprocess(full_img) - X = torch.from_numpy(img).unsqueeze(0) - - X = X.to(device=device) + img = img.unsqueeze(0) + img = img.to(device=device, dtype=torch.float32) with torch.no_grad(): - output = net(X) + output = net(img) if net.n_classes > 1: probs = F.softmax(output, dim=1) @@ -43,13 +39,12 @@ def predict_img(net, tf = transforms.Compose( [ transforms.ToPILImage(), - transforms.Resize(img_height), + transforms.Resize(full_img.shape[1]), transforms.ToTensor() ] ) probs = tf(probs.cpu()) - full_mask = probs.squeeze().cpu().numpy() if use_dense_crf: diff --git a/train.py b/train.py index d875b64..ebc80f6 100644 --- a/train.py +++ b/train.py @@ -13,6 +13,9 @@ from eval import eval_net from unet import UNet from utils import get_ids, split_train_val, get_imgs_and_masks, batch +from utils.dataset import BasicDataset +from torch.utils.data import DataLoader, random_split + dir_img = 'data/imgs/' dir_mask = 'data/masks/' dir_checkpoint = 'checkpoints/' @@ -26,23 +29,25 @@ def train_net(net, val_percent=0.15, save_cp=True, img_scale=0.5): - ids = get_ids(dir_img) - iddataset = split_train_val(ids, val_percent) + dataset = BasicDataset(dir_img, dir_mask, img_scale) + n_val = int(len(dataset) * val_percent) + n_train = len(dataset) - n_val + train, val = random_split(dataset, [n_train, n_val]) + train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4) logging.info(f'''Starting training: Epochs: {epochs} Batch size: {batch_size} Learning rate: {lr} - Training size: {len(iddataset["train"])} - Validation size: {len(iddataset["val"])} + Training size: {n_train} + Validation size: {n_val} Checkpoints: {save_cp} Device: {device.type} Images scaling: {img_scale} ''') - n_train = len(iddataset['train']) - n_val = len(iddataset['val']) optimizer = optim.Adam(net.parameters(), lr=lr) if net.n_classes > 1: criterion = nn.CrossEntropyLoss() @@ -52,21 +57,23 @@ def train_net(net, for epoch in range(epochs): net.train() - # reset the generators - train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale) - val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale) - epoch_loss = 0 with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: - for i, b in enumerate(batch(train, batch_size)): - imgs = np.array([i[0] for i in b]).astype(np.float32) - true_masks = np.array([i[1] for i in b]) + for batch in train_loader: + imgs = batch['image'] + true_masks = batch['mask'] + assert imgs.shape[1] == net.n_channels, \ + f'Network has been defined with {net.n_channels} input channels, ' \ + f'but loaded images have {imgs.shape[1]} channels. Please check that ' \ + 'the images are loaded correctly.' - imgs = torch.from_numpy(imgs) - true_masks = torch.from_numpy(true_masks) + assert true_masks.shape[1] == net.n_classes, \ + f'Network has been defined with {net.n_classes} output classes, ' \ + f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \ + 'the masks are loaded correctly.' - imgs = imgs.to(device=device) - true_masks = true_masks.to(device=device) + imgs = imgs.to(device=device, dtype=torch.float32) + true_masks = true_masks.to(device=device, dtype=torch.float32) masks_pred = net(imgs) loss = criterion(masks_pred, true_masks) @@ -90,7 +97,7 @@ def train_net(net, dir_checkpoint + f'CP_epoch{epoch + 1}.pth') logging.info(f'Checkpoint {epoch + 1} saved !') - val_score = eval_net(net, val, device, n_val) + val_score = eval_net(net, val_loader, device, n_val) if net.n_classes > 1: logging.info('Validation cross entropy: {}'.format(val_score)) @@ -117,18 +124,9 @@ def get_args(): return parser.parse_args() -def pretrain_checks(): - imgs = [f for f in os.listdir(dir_img) if not f.startswith('.')] - masks = [f for f in os.listdir(dir_mask) if not f.startswith('.')] - if len(imgs) != len(masks): - logging.warning(f'The number of images and masks do not match ! ' - f'{len(imgs)} images and {len(masks)} masks detected in the data folder.') - - if __name__ == '__main__': logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') args = get_args() - pretrain_checks() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logging.info(f'Using device {device}') diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..69d922a --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,60 @@ +from os.path import splitext +from os import listdir +import numpy as np +from glob import glob +import torch +from torch.utils.data import Dataset +import logging +from PIL import Image + + +class BasicDataset(Dataset): + def __init__(self, imgs_dir, masks_dir, scale=1): + self.imgs_dir = imgs_dir + self.masks_dir = masks_dir + self.scale = scale + assert 0 < scale <= 1, 'Scale must be between 0 and 1' + + self.ids = [splitext(file)[0] for file in listdir(imgs_dir) + if not file.startswith('.')] + logging.info(f'Creating dataset with {len(self.ids)} examples') + + def __len__(self): + return len(self.ids) + + def preprocess(self, pil_img): + w, h = pil_img.size + newW, newH = int(self.scale * w), int(self.scale * h) + pil_img = pil_img.resize((newW, newH)) + + img_nd = np.array(pil_img) + + if len(img_nd.shape) == 2: + img_nd = np.expand_dims(img_nd, axis=2) + + # HWC to CHW + img_trans = img_nd.transpose((2, 0, 1)) + if img_trans.max() > 1: + img_trans = img_trans / 255 + + return img_trans + + def __getitem__(self, i): + idx = self.ids[i] + mask_file = glob(self.masks_dir + idx + '*') + img_file = glob(self.imgs_dir + idx + '*') + + assert len(mask_file) == 1, \ + f'Either no mask or multiple masks found for the ID {idx}: {mask_file}' + assert len(img_file) == 1, \ + f'Either no image or multiple images found for the ID {idx}: {img_file}' + mask = Image.open(mask_file[0]) + img = Image.open(img_file[0]) + + assert img.size == mask.size, \ + f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}' + + img = self.preprocess(img) + mask = self.preprocess(mask) + + return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)} diff --git a/utils/load.py b/utils/load.py deleted file mode 100644 index 306fd54..0000000 --- a/utils/load.py +++ /dev/null @@ -1,40 +0,0 @@ -""" Utils on generators / lists of ids to transform from strings to cropped images and masks """ - -import os - -import numpy as np -from PIL import Image - -from .utils import resize_and_crop, normalize, hwc_to_chw - - -def get_ids(dir): - """Returns a list of the ids in the directory""" - return (os.path.splitext(f)[0] for f in os.listdir(dir) if not f.startswith('.')) - - -def to_cropped_imgs(ids, dir, suffix, scale): - """From a list of tuples, returns the correct cropped img""" - for id in ids: - im = resize_and_crop(Image.open(dir + id + suffix), scale=scale) - yield im - - -def get_imgs_and_masks(ids, dir_img, dir_mask, scale): - """Return all the couples (img, mask)""" - imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale) - - # need to transform from HWC to CHW - imgs_switched = map(hwc_to_chw, imgs) - imgs_normalized = map(normalize, imgs_switched) - - masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale) - masks_switched = map(hwc_to_chw, masks) - - return zip(imgs_normalized, masks_switched) - - -def get_full_img_and_mask(id, dir_img, dir_mask): - im = Image.open(dir_img + id + '.jpg') - mask = Image.open(dir_mask + id + '_mask.gif') - return np.array(im), np.array(mask) diff --git a/utils/utils.py b/utils/utils.py index 9066edd..eb2f855 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,56 +1,5 @@ -import random - import numpy as np - -def hwc_to_chw(img): - return np.transpose(img, axes=[2, 0, 1]) - - -def resize_and_crop(pilimg, scale=0.5, final_height=None): - w = pilimg.size[0] - h = pilimg.size[1] - newW = int(w * scale) - newH = int(h * scale) - - if not final_height: - diff = 0 - else: - diff = newH - final_height - - img = pilimg.resize((newW, newH)) - img = img.crop((0, diff // 2, newW, newH - diff // 2)) - ar = np.array(img, dtype=np.float32) - if len(ar.shape) == 2: - # for greyscale images, add a new axis - ar = np.expand_dims(ar, axis=2) - return ar - -def batch(iterable, batch_size): - """Yields lists by batch""" - b = [] - for i, t in enumerate(iterable): - b.append(t) - if (i + 1) % batch_size == 0: - yield b - b = [] - - if len(b) > 0: - yield b - - -def split_train_val(dataset, val_percent=0.05): - dataset = list(dataset) - length = len(dataset) - n = int(length * val_percent) - random.shuffle(dataset) - return {'train': dataset[:-n], 'val': dataset[-n:]} - - -def normalize(x): - return x / 255 - - # credits to https://stackoverflow.com/users/6076729/manuel-lagunas def rle_encode(mask_image): pixels = mask_image.flatten()