Now using utils.data.Dataset
Former-commit-id: c75d9c075e18add5cd8683faf827937393bf2c94
This commit is contained in:
parent
f5c2771242
commit
35f955cbf8
28
eval.py
28
eval.py
|
@ -5,27 +5,25 @@ from tqdm import tqdm
|
|||
from dice_loss import dice_coeff
|
||||
|
||||
|
||||
def eval_net(net, dataset, device, n_val):
|
||||
def eval_net(net, loader, device, n_val):
|
||||
"""Evaluation without the densecrf with the dice coefficient"""
|
||||
net.eval()
|
||||
tot = 0
|
||||
|
||||
for i, b in tqdm(enumerate(dataset), total=n_val, desc='Validation round', unit='img'):
|
||||
img = b[0]
|
||||
true_mask = b[1]
|
||||
for i, b in tqdm(enumerate(loader), desc='Validation round', unit='img'):
|
||||
imgs = b['image']
|
||||
true_masks = b['mask']
|
||||
|
||||
img = torch.from_numpy(img).unsqueeze(0)
|
||||
true_mask = torch.from_numpy(true_mask).unsqueeze(0)
|
||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=torch.float32)
|
||||
|
||||
img = img.to(device=device)
|
||||
true_mask = true_mask.to(device=device)
|
||||
mask_pred = net(imgs)
|
||||
|
||||
mask_pred = net(img).squeeze(dim=0)
|
||||
|
||||
mask_pred = (mask_pred > 0.5).float()
|
||||
if net.n_classes > 1:
|
||||
tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
|
||||
else:
|
||||
tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
|
||||
for true_mask in true_masks:
|
||||
mask_pred = (mask_pred > 0.5).float()
|
||||
if net.n_classes > 1:
|
||||
tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
|
||||
else:
|
||||
tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
|
||||
|
||||
return tot / n_val
|
||||
|
|
19
predict.py
19
predict.py
|
@ -10,8 +10,7 @@ import torch.nn.functional as F
|
|||
|
||||
from unet import UNet
|
||||
from utils import plot_img_and_mask
|
||||
from utils import resize_and_crop, normalize, hwc_to_chw, dense_crf
|
||||
|
||||
from utils.dataset import BasicDataset
|
||||
|
||||
def predict_img(net,
|
||||
full_img,
|
||||
|
@ -20,18 +19,15 @@ def predict_img(net,
|
|||
out_threshold=0.5,
|
||||
use_dense_crf=False):
|
||||
net.eval()
|
||||
img_height = full_img.size[1]
|
||||
|
||||
img = resize_and_crop(full_img, scale=scale_factor)
|
||||
img = normalize(img)
|
||||
img = hwc_to_chw(img)
|
||||
ds = BasicDataset('', '', scale=scale_factor)
|
||||
img = ds.preprocess(full_img)
|
||||
|
||||
X = torch.from_numpy(img).unsqueeze(0)
|
||||
|
||||
X = X.to(device=device)
|
||||
img = img.unsqueeze(0)
|
||||
img = img.to(device=device, dtype=torch.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
output = net(X)
|
||||
output = net(img)
|
||||
|
||||
if net.n_classes > 1:
|
||||
probs = F.softmax(output, dim=1)
|
||||
|
@ -43,13 +39,12 @@ def predict_img(net,
|
|||
tf = transforms.Compose(
|
||||
[
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(img_height),
|
||||
transforms.Resize(full_img.shape[1]),
|
||||
transforms.ToTensor()
|
||||
]
|
||||
)
|
||||
|
||||
probs = tf(probs.cpu())
|
||||
|
||||
full_mask = probs.squeeze().cpu().numpy()
|
||||
|
||||
if use_dense_crf:
|
||||
|
|
52
train.py
52
train.py
|
@ -13,6 +13,9 @@ from eval import eval_net
|
|||
from unet import UNet
|
||||
from utils import get_ids, split_train_val, get_imgs_and_masks, batch
|
||||
|
||||
from utils.dataset import BasicDataset
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
dir_img = 'data/imgs/'
|
||||
dir_mask = 'data/masks/'
|
||||
dir_checkpoint = 'checkpoints/'
|
||||
|
@ -26,23 +29,25 @@ def train_net(net,
|
|||
val_percent=0.15,
|
||||
save_cp=True,
|
||||
img_scale=0.5):
|
||||
ids = get_ids(dir_img)
|
||||
|
||||
iddataset = split_train_val(ids, val_percent)
|
||||
dataset = BasicDataset(dir_img, dir_mask, img_scale)
|
||||
n_val = int(len(dataset) * val_percent)
|
||||
n_train = len(dataset) - n_val
|
||||
train, val = random_split(dataset, [n_train, n_val])
|
||||
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4)
|
||||
|
||||
logging.info(f'''Starting training:
|
||||
Epochs: {epochs}
|
||||
Batch size: {batch_size}
|
||||
Learning rate: {lr}
|
||||
Training size: {len(iddataset["train"])}
|
||||
Validation size: {len(iddataset["val"])}
|
||||
Training size: {n_train}
|
||||
Validation size: {n_val}
|
||||
Checkpoints: {save_cp}
|
||||
Device: {device.type}
|
||||
Images scaling: {img_scale}
|
||||
''')
|
||||
|
||||
n_train = len(iddataset['train'])
|
||||
n_val = len(iddataset['val'])
|
||||
optimizer = optim.Adam(net.parameters(), lr=lr)
|
||||
if net.n_classes > 1:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
@ -52,21 +57,23 @@ def train_net(net,
|
|||
for epoch in range(epochs):
|
||||
net.train()
|
||||
|
||||
# reset the generators
|
||||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
|
||||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
|
||||
|
||||
epoch_loss = 0
|
||||
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
|
||||
for i, b in enumerate(batch(train, batch_size)):
|
||||
imgs = np.array([i[0] for i in b]).astype(np.float32)
|
||||
true_masks = np.array([i[1] for i in b])
|
||||
for batch in train_loader:
|
||||
imgs = batch['image']
|
||||
true_masks = batch['mask']
|
||||
assert imgs.shape[1] == net.n_channels, \
|
||||
f'Network has been defined with {net.n_channels} input channels, ' \
|
||||
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
|
||||
'the images are loaded correctly.'
|
||||
|
||||
imgs = torch.from_numpy(imgs)
|
||||
true_masks = torch.from_numpy(true_masks)
|
||||
assert true_masks.shape[1] == net.n_classes, \
|
||||
f'Network has been defined with {net.n_classes} output classes, ' \
|
||||
f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
|
||||
'the masks are loaded correctly.'
|
||||
|
||||
imgs = imgs.to(device=device)
|
||||
true_masks = true_masks.to(device=device)
|
||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=torch.float32)
|
||||
|
||||
masks_pred = net(imgs)
|
||||
loss = criterion(masks_pred, true_masks)
|
||||
|
@ -90,7 +97,7 @@ def train_net(net,
|
|||
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
|
||||
logging.info(f'Checkpoint {epoch + 1} saved !')
|
||||
|
||||
val_score = eval_net(net, val, device, n_val)
|
||||
val_score = eval_net(net, val_loader, device, n_val)
|
||||
if net.n_classes > 1:
|
||||
logging.info('Validation cross entropy: {}'.format(val_score))
|
||||
|
||||
|
@ -117,18 +124,9 @@ def get_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def pretrain_checks():
    """Warn when the image and mask folders hold different numbers of files.

    Entries whose name starts with '.' are ignored in both folders. Only the
    two counts are compared; a mismatch is logged as a warning and nothing is
    raised.
    """
    def visible_entries(folder):
        # Hidden files (e.g. '.gitkeep', '.DS_Store') are not data.
        return [name for name in os.listdir(folder) if not name.startswith('.')]

    imgs = visible_entries(dir_img)
    masks = visible_entries(dir_mask)
    if len(imgs) == len(masks):
        return
    logging.warning(f'The number of images and masks do not match ! '
                    f'{len(imgs)} images and {len(masks)} masks detected in the data folder.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
args = get_args()
|
||||
pretrain_checks()
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logging.info(f'Using device {device}')
|
||||
|
||||
|
|
60
utils/dataset.py
Normal file
60
utils/dataset.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from os.path import splitext
|
||||
from os import listdir
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import logging
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class BasicDataset(Dataset):
    """Dataset pairing every image in `imgs_dir` with its mask in `masks_dir`.

    An example id is an image file name without its extension. The matching
    mask is any single file in `masks_dir` whose name starts with that id, so
    masks may carry a suffix after the id.

    Args:
        imgs_dir: folder holding the images. An empty string creates an empty
            dataset that can still be used for `preprocess`.
        masks_dir: folder holding the masks.
        scale: resize factor applied to both image and mask, in (0, 1].
    """

    def __init__(self, imgs_dir, masks_dir, scale=1):
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale

        # listdir('') raises, yet a dataset built with empty folders is a
        # legitimate way to reach preprocess() without any data on disk.
        names = listdir(imgs_dir) if imgs_dir else []
        self.ids = [splitext(file)[0] for file in names
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        """Return the number of example ids found in `imgs_dir`."""
        return len(self.ids)

    def preprocess(self, pil_img):
        """Resize a PIL image by `self.scale` and return it as a CHW array.

        2-D (single channel) inputs get a channel axis added. Values are
        divided by 255 only when the array maximum exceeds 1.

        Raises:
            AssertionError: if the scaled width or height rounds down to 0.
        """
        w, h = pil_img.size
        newW, newH = int(self.scale * w), int(self.scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'
        pil_img = pil_img.resize((newW, newH))

        img_nd = np.array(pil_img)
        if img_nd.ndim == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans

    def __getitem__(self, i):
        """Load, check and preprocess the i-th image/mask pair.

        Returns:
            dict with keys 'image' and 'mask', each a tensor built from the
            preprocessed array.

        Raises:
            AssertionError: if the id does not resolve to exactly one image
                and one mask, or if their sizes differ.
        """
        from os.path import basename, join

        idx = self.ids[i]
        # Masks may carry a suffix after the id, so they stay a prefix match.
        mask_file = glob(join(self.masks_dir, idx + '*'))
        # Ids come from image file stems: keep only exact stem matches so an
        # id that is a prefix of another id does not pick up both images.
        img_file = [path for path in glob(join(self.imgs_dir, idx + '*'))
                    if splitext(basename(path))[0] == idx]

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'

        # Context managers release the file handles once the pixel data has
        # been copied into arrays.
        with Image.open(mask_file[0]) as mask_pil, Image.open(img_file[0]) as img_pil:
            assert img_pil.size == mask_pil.size, \
                f'Image and mask {idx} should be the same size, but are {img_pil.size} and {mask_pil.size}'

            img = self.preprocess(img_pil)
            mask = self.preprocess(mask_pil)

        return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
|
|
@ -1,40 +0,0 @@
|
|||
""" Utils on generators / lists of ids to transform from strings to cropped images and masks """
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .utils import resize_and_crop, normalize, hwc_to_chw
|
||||
|
||||
|
||||
def get_ids(dir):
    """Return a generator over the ids found in the directory.

    An id is a file name with its last extension removed. Entries whose name
    starts with '.' are skipped. The directory is listed when this function
    is called; the ids themselves are produced lazily.
    """
    entries = os.listdir(dir)
    return (os.path.splitext(name)[0] for name in entries if not name.startswith('.'))
|
||||
|
||||
|
||||
def to_cropped_imgs(ids, dir, suffix, scale):
    """Yield one resized image per id, lazily.

    For each id the file at `dir + id + suffix` is opened, passed to
    `resize_and_crop` with the given `scale`, and the result is yielded.
    """
    for img_id in ids:
        # Close the file before yielding so a slow consumer does not
        # accumulate open handles.
        with Image.open(dir + img_id + suffix) as pil_img:
            im = resize_and_crop(pil_img, scale=scale)
        yield im
|
||||
|
||||
|
||||
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
    """Return a lazy iterator over all the couples (img, mask).

    Images are read from `dir_img + id + '.jpg'`, switched from HWC to CHW
    and normalized. Masks are read from `dir_mask + id + '_mask.gif'` and
    switched from HWC to CHW.
    """
    # Two independent generators walk over `ids`. With a one-shot iterator
    # they would steal ids from each other and pair image N with mask N+1,
    # so the ids are materialized once up front.
    ids = list(ids)

    imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)

    # need to transform from HWC to CHW
    imgs_switched = map(hwc_to_chw, imgs)
    imgs_normalized = map(normalize, imgs_switched)

    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale)
    masks_switched = map(hwc_to_chw, masks)

    return zip(imgs_normalized, masks_switched)
|
||||
|
||||
|
||||
def get_full_img_and_mask(id, dir_img, dir_mask):
    """Load the full-size image and mask for one id as numpy arrays.

    The image is read from `dir_img + id + '.jpg'` and the mask from
    `dir_mask + id + '_mask.gif'`. No resizing is applied.

    Returns:
        A tuple `(image_array, mask_array)`.
    """
    # Convert inside the `with` blocks so the pixel data is read before the
    # underlying files are closed.
    with Image.open(dir_img + id + '.jpg') as im:
        im_array = np.array(im)
    with Image.open(dir_mask + id + '_mask.gif') as mask:
        mask_array = np.array(mask)
    return im_array, mask_array
|
|
@ -1,56 +1,5 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def hwc_to_chw(img):
    """Move the last axis of a 3-axis array to the front (HWC -> CHW)."""
    channel_first_order = (2, 0, 1)
    return np.transpose(img, channel_first_order)
|
||||
|
||||
|
||||
def resize_and_crop(pilimg, scale=0.5, final_height=None):
    """Resize a PIL image by `scale`, optionally crop its height, return an array.

    Args:
        pilimg: image exposing `size`, `resize` and `crop` like a PIL image.
        scale: factor applied to both width and height (truncated to int).
        final_height: when set (and non-zero), the resized image is cropped
            vertically around its center to exactly this many rows.

    Returns:
        A float32 array with three axes; images that convert to a 2-D array
        get a trailing axis of length 1.
    """
    width, height = pilimg.size
    new_w = int(width * scale)
    new_h = int(height * scale)

    img = pilimg.resize((new_w, new_h))
    if final_height:
        # Derive the bottom edge from the top edge: removing `diff // 2` on
        # both sides left one row too many whenever the difference was odd.
        top = (new_h - final_height) // 2
        img = img.crop((0, top, new_w, top + final_height))

    ar = np.array(img, dtype=np.float32)
    if ar.ndim == 2:
        # for greyscale images, add a new axis
        ar = np.expand_dims(ar, axis=2)
    return ar
|
||||
|
||||
def batch(iterable, batch_size):
    """Yield consecutive lists of `batch_size` items taken from `iterable`.

    The final list is shorter when the item count is not a multiple of
    `batch_size`; nothing is yielded for an empty iterable.
    """
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) % batch_size == 0:
            yield chunk
            chunk = []

    # Flush whatever is left over after the last full batch.
    if chunk:
        yield chunk
|
||||
|
||||
|
||||
def split_train_val(dataset, val_percent=0.05):
    """Randomly split `dataset` into a training part and a validation part.

    Args:
        dataset: any iterable; it is copied into a list and not modified.
        val_percent: fraction of the items that goes to validation. The
            count is `int(len * val_percent)`, clamped to `[0, len]`.

    Returns:
        dict with keys 'train' and 'val', each a list of items.
    """
    items = list(dataset)
    n_val = int(len(items) * val_percent)
    n_val = min(max(n_val, 0), len(items))
    random.shuffle(items)

    # Split on a positive index: with a zero-sized validation set the
    # negative slices `[:-0]` / `[-0:]` put every item in 'val'.
    split_at = len(items) - n_val
    return {'train': items[:split_at], 'val': items[split_at:]}
|
||||
|
||||
|
||||
def normalize(x):
    """Divide `x` by 255, mapping the 0-255 range onto 0-1."""
    scaled = x / 255
    return scaled
|
||||
|
||||
|
||||
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
|
||||
def rle_encode(mask_image):
|
||||
pixels = mask_image.flatten()
|
||||
|
|
Loading…
Reference in a new issue