diff --git a/.gitignore b/.gitignore index da7b4f9..6a01e6c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,5 @@ __pycache__/ checkpoints/ *.pth *.jpg -SUBMISSION* -venv/ \ No newline at end of file +venv/ +.idea/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c106ccc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,9 @@ +FROM nvcr.io/nvidia/pytorch:21.06-py3 + +RUN rm -rf /workspace/* +WORKDIR /workspace/unet + +ADD requirements.txt . +RUN pip install --no-cache-dir --upgrade --pre pip +RUN pip install --no-cache-dir -r requirements.txt +ADD . . diff --git a/data_loading.py b/data_loading.py new file mode 100644 index 0000000..e2e1dd7 --- /dev/null +++ b/data_loading.py @@ -0,0 +1,80 @@ +import logging +from os import listdir +from os.path import splitext +from pathlib import Path + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + + +class BasicDataset(Dataset): + def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''): + self.images_dir = Path(images_dir) + self.masks_dir = Path(masks_dir) + assert 0 < scale <= 1, 'Scale must be between 0 and 1' + self.scale = scale + self.mask_suffix = mask_suffix + + self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')] + if not self.ids: + raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') + logging.info(f'Creating dataset with {len(self.ids)} examples') + + def __len__(self): + return len(self.ids) + + @classmethod + def preprocess(cls, pil_img, scale, is_mask): + w, h = pil_img.size + newW, newH = int(scale * w), int(scale * h) + assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' + pil_img = pil_img.resize((newW, newH)) + img_ndarray = np.asarray(pil_img) + + if img_ndarray.ndim == 2 and not is_mask: + img_ndarray = img_ndarray[np.newaxis, ...] 
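+        # non-mask color images arrive as HWC; the next branch transposes them to
+        # the CHW layout PyTorch expects (grayscale images got their channel axis above)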
+        elif not is_mask:
+            img_ndarray = img_ndarray.transpose((2, 0, 1))
+
+        if not is_mask:
+            img_ndarray = img_ndarray / 255
+
+        return img_ndarray
+
+    @classmethod
+    def load(cls, filename):
+        ext = splitext(filename)[1]
+        if ext in ['.npz', '.npy']:
+            return Image.fromarray(np.load(filename))
+        elif ext in ['.pt', '.pth']:
+            return Image.fromarray(torch.load(filename).numpy())
+        else:
+            return Image.open(filename)
+
+    def __getitem__(self, idx):
+        name = self.ids[idx]
+        mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
+        img_file = list(self.images_dir.glob(name + '.*'))
+
+        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
+        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
+        mask = self.load(mask_file[0])
+        img = self.load(img_file[0])
+
+        assert img.size == mask.size, \
+            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
+
+        img = self.preprocess(img, self.scale, is_mask=False)
+        mask = self.preprocess(mask, self.scale, is_mask=True)
+
+        return {
+            'image': torch.as_tensor(img.copy()).float().contiguous(),
+            'mask': torch.as_tensor(mask.copy()).long().contiguous()
+        }
+
+
+class CarvanaDataset(BasicDataset):
+    def __init__(self, images_dir, masks_dir, scale=1):
+        super().__init__(images_dir, masks_dir, scale, mask_suffix='_mask')
diff --git a/dice_loss.py b/dice_loss.py
deleted file mode 100644
index fe86611..0000000
--- a/dice_loss.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from torch.autograd import Function
-
-
-class DiceCoeff(Function):
-    """Dice coeff for individual examples"""
-
-    def forward(self, input, target):
-        self.save_for_backward(input, target)
-        eps = 0.0001
-        self.inter = torch.dot(input.view(-1), target.view(-1))
-        self.union = torch.sum(input) + torch.sum(target) + eps
-
-        t = (2 * self.inter.float() + eps) / self.union.float()
-        return t
-
-    # This function has only a single output, so it gets only one gradient
-    def backward(self, grad_output):
-
-        input, target = self.saved_variables
-        grad_input = grad_target = None
-
-        if self.needs_input_grad[0]:
-            grad_input = grad_output * 2 * (target * self.union - self.inter) \
-                         / (self.union * self.union)
-        if self.needs_input_grad[1]:
-            grad_target = None
-
-        return grad_input, grad_target
-
-
-def dice_coeff(input, target):
-    """Dice coeff for batches"""
-    if input.is_cuda:
-        s = torch.FloatTensor(1).cuda().zero_()
-    else:
-        s = torch.FloatTensor(1).zero_()
-
-    for i, c in enumerate(zip(input, target)):
-        s = s + DiceCoeff().forward(c[0], c[1])
-
-    return s / (i + 1)
diff --git a/dice_score.py b/dice_score.py
new file mode 100644
index 0000000..f69a286
--- /dev/null
+++ b/dice_score.py
@@ -0,0 +1,40 @@
+import torch
+from torch import Tensor
+
+
+def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
+    # Average of Dice coefficient for all batches, or for a single mask
+    assert input.size() == target.size()
+    if input.dim() == 2 and reduce_batch_first:
+        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
+
+    if input.dim() == 2 or reduce_batch_first:
+        inter = torch.dot(input.reshape(-1), target.reshape(-1))
+        sets_sum = torch.sum(input) + torch.sum(target)
+        if sets_sum.item() == 0:
+            sets_sum = 2 * inter
+
+        return (2 * inter + epsilon) / (sets_sum + epsilon)
+    else:
+        # compute and average metric for each batch element
+        dice = 0
+        for i in range(input.shape[0]):
+            dice += dice_coeff(input[i, ...], target[i, ...])
+        return dice / input.shape[0]
+
+
+def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
+    # Average of Dice coefficient for all classes
+    assert input.size() == target.size()
+    dice = 0
+    for channel in range(input.shape[1]):
+        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
+
+    return dice / input.shape[1]
+
+
+def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
+    # Dice loss (objective to minimize) between 0 and 1
+    assert input.size() == target.size()
+    fn = multiclass_dice_coeff if multiclass else dice_coeff
+    return 1 - fn(input, target, reduce_batch_first=True)
diff --git a/eval.py b/eval.py
deleted file mode 100644
index d126c1e..0000000
--- a/eval.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import torch
-import torch.nn.functional as F
-from tqdm import tqdm
-
-from dice_loss import dice_coeff
-
-
-def eval_net(net, loader, device):
-    """Evaluation without the densecrf with the dice coefficient"""
-    net.eval()
-    mask_type = torch.float32 if net.n_classes == 1 else torch.long
-    n_val = len(loader)  # the number of batch
-    tot = 0
-
-    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
-        for batch in loader:
-            imgs, true_masks = batch['image'], batch['mask']
-            imgs = imgs.to(device=device, dtype=torch.float32)
-            true_masks = true_masks.to(device=device, dtype=mask_type)
-
-            with torch.no_grad():
-                mask_pred = net(imgs)
-
-            if net.n_classes > 1:
-                tot += F.cross_entropy(mask_pred, true_masks).item()
-            else:
-                pred = torch.sigmoid(mask_pred)
-                pred = (pred > 0.5).float()
-                tot += dice_coeff(pred, true_masks).item()
-            pbar.update()
-
-    net.train()
-    return tot / n_val
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..053d726
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,36 @@
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from dice_score import dice_coeff, multiclass_dice_coeff
+
+
+def evaluate(net, dataloader, device):
+    net.eval()
+    num_val_batches = len(dataloader)
+    dice_score = 0
+
+    # iterate over the validation set
+    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
+        image, mask_true = batch['image'], batch['mask']
+        # move images and labels to correct device and type
+        image = image.to(device=device, dtype=torch.float32)
+        mask_true = mask_true.to(device=device, dtype=torch.long)
+        mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
+
+        with torch.no_grad():
+            # predict the mask
+            mask_pred = net(image)
+
+            # convert to one-hot format
+            if net.n_classes == 1:
+                mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
+                # compute the Dice score
+                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
+            else:
+                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
+                # compute the Dice score, ignoring background
+                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
+
+    net.train()
+    return dice_score / num_val_batches
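
As a quick sanity check of the new Dice helpers, the following minimal snippet (illustrative only, not part of the diff) should run as-is: a prediction identical to its one-hot target scores a Dice of 1 and a loss of 0.

    import torch
    import torch.nn.functional as F

    from dice_score import dice_loss, multiclass_dice_coeff

    # integer class masks (N, H, W), one-hot encoded to (N, C, H, W) as train.py does
    target = F.one_hot(torch.randint(0, 2, (4, 32, 32)), num_classes=2).permute(0, 3, 1, 2).float()

    assert torch.isclose(multiclass_dice_coeff(target, target), torch.tensor(1.))
    assert torch.isclose(dice_loss(target, target, multiclass=True), torch.tensor(0.))

Note that the non-contiguous channel slices taken by multiclass_dice_coeff are why dice_coeff flattens with reshape rather than view.
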
diff --git a/predict.py b/predict.py
index fd12ed0..8e158db 100755
--- a/predict.py
+++ b/predict.py
@@ -8,9 +8,9 @@ import torch.nn.functional as F
 from PIL import Image
 from torchvision import transforms
 
+from data_loading import BasicDataset
 from unet import UNet
-from utils.data_vis import plot_img_and_mask
-from utils.dataset import BasicDataset
+from utils import plot_img_and_mask
 
 
 def predict_img(net,
@@ -19,9 +19,7 @@ def predict_img(net,
                 scale_factor=1,
                 out_threshold=0.5):
     net.eval()
-
-    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
-
+    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
     img = img.unsqueeze(0)
     img = img.to(device=device, dtype=torch.float32)
 
@@ -29,94 +27,75 @@ def predict_img(net,
         output = net(img)
 
         if net.n_classes > 1:
-            probs = F.softmax(output, dim=1)
+            probs = F.softmax(output, dim=1)[0]
         else:
-            probs = torch.sigmoid(output)
+            probs = torch.sigmoid(output)[0]
 
-        probs = probs.squeeze(0)
+        tf = transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((full_img.size[1], full_img.size[0])),
+            transforms.ToTensor()
+        ])
 
-        tf = transforms.Compose(
-            [
-                transforms.ToPILImage(),
-                transforms.Resize(full_img.size[1]),
-                transforms.ToTensor()
-            ]
-        )
+        full_mask = tf(probs.cpu()).squeeze()
 
-        probs = tf(probs.cpu())
-        full_mask = probs.squeeze().cpu().numpy()
-
-    return full_mask > out_threshold
+    if net.n_classes == 1:
+        return (full_mask > out_threshold).numpy()
+    else:
+        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
 
 
 def get_args():
-    parser = argparse.ArgumentParser(description='Predict masks from input images',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--model', '-m', default='MODEL.pth',
-                        metavar='FILE',
-                        help="Specify the file in which the model is stored")
-    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
-                        help='filenames of input images', required=True)
-
-    parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
-                        help='Filenames of ouput images')
+    parser = argparse.ArgumentParser(description='Predict masks from input images')
+    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
+                        help='Specify the file in which the model is stored')
+    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
+    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
     parser.add_argument('--viz', '-v', action='store_true',
-                        help="Visualize the images as they are processed",
-                        default=False)
-    parser.add_argument('--no-save', '-n', action='store_true',
-                        help="Do not save the output masks",
-                        default=False)
-    parser.add_argument('--mask-threshold', '-t', type=float,
-                        help="Minimum probability value to consider a mask pixel white",
-                        default=0.5)
-    parser.add_argument('--scale', '-s', type=float,
-                        help="Scale factor for the input images",
-                        default=0.5)
+                        help='Visualize the images as they are processed')
+    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
+    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
+                        help='Minimum probability value to consider a mask pixel white')
+    parser.add_argument('--scale', '-s', type=float, default=0.5,
+                        help='Scale factor for the input images')
 
     return parser.parse_args()
 
 
 def get_output_filenames(args):
-    in_files = args.input
-    out_files = []
+    def _generate_name(fn):
+        split = os.path.splitext(fn)
+        return f'{split[0]}_OUT{split[1]}'
 
-    if not args.output:
-        for f in in_files:
-            pathsplit = os.path.splitext(f)
-            out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
-    elif len(in_files) != len(args.output):
-        logging.error("Input files and output files are not of the same length")
-        raise SystemExit()
-    else:
-        out_files = args.output
-
-    return out_files
+    return args.output or list(map(_generate_name, args.input))
 
 
-def mask_to_image(mask):
-    return Image.fromarray((mask * 255).astype(np.uint8))
+def mask_to_image(mask: np.ndarray):
+    if mask.ndim == 2:
+        return Image.fromarray((mask * 255).astype(np.uint8))
+    elif mask.ndim == 3:
+        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     args = get_args()
     in_files = args.input
     out_files = get_output_filenames(args)
 
-    net = UNet(n_channels=3, n_classes=1)
-
-    logging.info("Loading model {}".format(args.model))
+    net = UNet(n_channels=3, n_classes=2)
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    logging.info(f'Loading model {args.model}')
     logging.info(f'Using device {device}')
+
     net.to(device=device)
     net.load_state_dict(torch.load(args.model, map_location=device))
-    logging.info("Model loaded !")
+    logging.info('Model loaded!')
 
-    for i, fn in enumerate(in_files):
-        logging.info("\nPredicting image {} ...".format(fn))
-
-        img = Image.open(fn)
+    for i, filename in enumerate(in_files):
+        logging.info(f'\nPredicting image {filename} ...')
+        img = Image.open(filename)
 
         mask = predict_img(net=net,
                            full_img=img,
@@ -125,12 +104,11 @@ if __name__ == "__main__":
                            device=device)
 
         if not args.no_save:
-            out_fn = out_files[i]
+            out_filename = out_files[i]
             result = mask_to_image(mask)
-            result.save(out_files[i])
-
-            logging.info("Mask saved to {}".format(out_files[i]))
+            result.save(out_filename)
+            logging.info(f'Mask saved to {out_filename}')
 
         if args.viz:
-            logging.info("Visualizing results for image {}, close to continue ...".format(fn))
+            logging.info(f'Visualizing results for image {filename}, close to continue...')
             plot_img_and_mask(img, mask)
diff --git a/requirements.txt b/requirements.txt
index 379669c..b653501 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,5 @@ numpy
 Pillow
 torch
 torchvision
-tensorboard
-future
 tqdm
+wandb
\ No newline at end of file
diff --git a/submit.py b/submit.py
deleted file mode 100644
index a5609cd..0000000
--- a/submit.py
+++ /dev/null
@@ -1,46 +0,0 @@
-""" Submit code specific to the kaggle challenge"""
-
-import os
-
-import torch
-from PIL import Image
-import numpy as np
-
-from predict import predict_img
-from unet import UNet
-
-# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
-def rle_encode(mask_image):
-    pixels = mask_image.flatten()
-    # We avoid issues with '1' at the start or end (at the corners of
-    # the original image) by setting those pixels to '0' explicitly.
-    # We do not expect these to be non-zero for an accurate mask,
-    # so this should not harm the score.
- pixels[0] = 0 - pixels[-1] = 0 - runs = np.where(pixels[1:] != pixels[:-1])[0] + 2 - runs[1::2] = runs[1::2] - runs[:-1:2] - return runs - - -def submit(net): - """Used for Kaggle submission: predicts and encode all test images""" - dir = 'data/test/' - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - N = len(list(os.listdir(dir))) - with open('SUBMISSION.csv', 'a') as f: - f.write('img,rle_mask\n') - for index, i in enumerate(os.listdir(dir)): - print('{}/{}'.format(index, N)) - - img = Image.open(dir + i) - - mask = predict_img(net, img, device) - enc = rle_encode(mask) - f.write('{},{}\n'.format(i, ' '.join(map(str, enc)))) - - -if __name__ == '__main__': - net = UNet(3, 1).cuda() - net.load_state_dict(torch.load('MODEL.pth')) - submit(net) diff --git a/train.py b/train.py index c7bc600..d10df0b 100644 --- a/train.py +++ b/train.py @@ -1,187 +1,193 @@ import argparse import logging -import os import sys +from pathlib import Path -import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F +import wandb from torch import optim +from torch.utils.data import DataLoader, random_split from tqdm import tqdm -from eval import eval_net +from data_loading import BasicDataset, CarvanaDataset +from dice_score import dice_loss +from evaluate import evaluate from unet import UNet -from torch.utils.tensorboard import SummaryWriter -from utils.dataset import BasicDataset -from torch.utils.data import DataLoader, random_split - -dir_img = 'data/imgs/' -dir_mask = 'data/masks/' -dir_checkpoint = 'checkpoints/' +dir_img = Path('./data/imgs/') +dir_mask = Path('./data/masks/') +dir_checkpoint = Path('./checkpoints/') def train_net(net, device, - epochs=5, - batch_size=1, - lr=0.001, - val_percent=0.1, - save_cp=True, - img_scale=0.5): + epochs: int = 5, + batch_size: int = 1, + learning_rate: float = 0.001, + val_percent: float = 0.1, + save_checkpoint: bool = True, + img_scale: float = 0.5, + amp: bool = False): + # 1. Create dataset + try: + dataset = CarvanaDataset(dir_img, dir_mask, img_scale) + except (AssertionError, RuntimeError): + dataset = BasicDataset(dir_img, dir_mask, img_scale) - dataset = BasicDataset(dir_img, dir_mask, img_scale) + # 2. Split into train / validation partitions n_val = int(len(dataset) * val_percent) n_train = len(dataset) - n_val - train, val = random_split(dataset, [n_train, n_val]) - train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) - val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True) + train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0)) - writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}') - global_step = 0 + # 3. 
Create data loaders
+    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
+    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
+    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
+
+    # (Initialise logging)
+    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
+    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
+                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
+                                  amp=amp))
 
     logging.info(f'''Starting training:
         Epochs:          {epochs}
         Batch size:      {batch_size}
-        Learning rate:   {lr}
+        Learning rate:   {learning_rate}
         Training size:   {n_train}
         Validation size: {n_val}
-        Checkpoints:     {save_cp}
+        Checkpoints:     {save_checkpoint}
         Device:          {device.type}
         Images scaling:  {img_scale}
+        Mixed Precision: {amp}
     ''')
 
-    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
-    if net.n_classes > 1:
-        criterion = nn.CrossEntropyLoss()
-    else:
-        criterion = nn.BCEWithLogitsLoss()
+    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
+    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
+    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    criterion = nn.CrossEntropyLoss()
+    global_step = 0
 
+    # 5. Begin training
     for epoch in range(epochs):
         net.train()
         epoch_loss = 0
         with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
             for batch in train_loader:
-                imgs = batch['image']
+                images = batch['image']
                 true_masks = batch['mask']
-                assert imgs.shape[1] == net.n_channels, \
+
+                assert images.shape[1] == net.n_channels, \
                     f'Network has been defined with {net.n_channels} input channels, ' \
-                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
+                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                     'the images are loaded correctly.'
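+                # images go to the training device as float for the forward pass;
+                # masks stay integer class indices, as CrossEntropyLoss and F.one_hot expect long tensors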
- imgs = imgs.to(device=device, dtype=torch.float32) - mask_type = torch.float32 if net.n_classes == 1 else torch.long - true_masks = true_masks.to(device=device, dtype=mask_type) + images = images.to(device=device, dtype=torch.float32) + true_masks = true_masks.to(device=device, dtype=torch.long) - masks_pred = net(imgs) - loss = criterion(masks_pred, true_masks) + with torch.cuda.amp.autocast(enabled=amp): + masks_pred = net(images) + loss = criterion(masks_pred, true_masks) \ + + dice_loss(F.softmax(masks_pred, dim=1).float(), + F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(), + multiclass=True) + + optimizer.zero_grad(set_to_none=True) + grad_scaler.scale(loss).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + + pbar.update(images.shape[0]) + global_step += 1 epoch_loss += loss.item() - writer.add_scalar('Loss/train', loss.item(), global_step) - + experiment.log({ + 'train loss': loss.item(), + 'step': global_step, + 'epoch': epoch + }) pbar.set_postfix(**{'loss (batch)': loss.item()}) - optimizer.zero_grad() - loss.backward() - nn.utils.clip_grad_value_(net.parameters(), 0.1) - optimizer.step() - - pbar.update(imgs.shape[0]) - global_step += 1 + # Evaluation round if global_step % (n_train // (10 * batch_size)) == 0: + histograms = {} for tag, value in net.named_parameters(): - tag = tag.replace('.', '/') - writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step) - writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step) - val_score = eval_net(net, val_loader, device) + tag = tag.replace('/', '.') + histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) + histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) + + val_score = evaluate(net, val_loader, device) scheduler.step(val_score) - writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step) - if net.n_classes > 1: - logging.info('Validation cross entropy: {}'.format(val_score)) - writer.add_scalar('Loss/test', val_score, global_step) - else: - logging.info('Validation Dice Coeff: {}'.format(val_score)) - writer.add_scalar('Dice/test', val_score, global_step) + logging.info('Validation Dice score: {}'.format(val_score)) + experiment.log({ + 'learning rate': optimizer.param_groups[0]['lr'], + 'validation Dice': val_score, + 'images': wandb.Image(images[0].cpu()), + 'masks': { + 'true': wandb.Image(true_masks[0].float().cpu()), + 'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()), + }, + 'step': global_step, + 'epoch': epoch, + **histograms + }) - writer.add_images('images', imgs, global_step) - if net.n_classes == 1: - writer.add_images('masks/true', true_masks, global_step) - writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step) - - if save_cp: - try: - os.mkdir(dir_checkpoint) - logging.info('Created checkpoint directory') - except OSError: - pass - torch.save(net.state_dict(), - dir_checkpoint + f'CP_epoch{epoch + 1}.pth') - logging.info(f'Checkpoint {epoch + 1} saved !') - - writer.close() + if save_checkpoint: + Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1))) + logging.info(f'Checkpoint {epoch + 1} saved!') def get_args(): - parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5, - help='Number 
of epochs', dest='epochs') - parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1, - help='Batch size', dest='batchsize') - parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.0001, + parser = argparse.ArgumentParser(description='Train the UNet on images and target masks') + parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs') + parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size') + parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001, help='Learning rate', dest='lr') - parser.add_argument('-f', '--load', dest='load', type=str, default=False, - help='Load model from a .pth file') - parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5, - help='Downscaling factor of the images') - parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0, + parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file') + parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images') + parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0, help='Percent of the data that is used as validation (0-100)') + parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision') return parser.parse_args() if __name__ == '__main__': - logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') args = get_args() + + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logging.info(f'Using device {device}') # Change here to adapt to your data # n_channels=3 for RGB images # n_classes is the number of probabilities you want to get per pixel - # - For 1 class and background, use n_classes=1 - # - For 2 classes, use n_classes=1 - # - For N > 2 classes, use n_classes=N - net = UNet(n_channels=3, n_classes=1, bilinear=True) + net = UNet(n_channels=3, n_classes=2, bilinear=True) + logging.info(f'Network:\n' f'\t{net.n_channels} input channels\n' f'\t{net.n_classes} output channels (classes)\n' f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling') if args.load: - net.load_state_dict( - torch.load(args.load, map_location=device) - ) + net.load_state_dict(torch.load(args.load, map_location=device)) logging.info(f'Model loaded from {args.load}') net.to(device=device) - # faster convolutions, but more memory - # cudnn.benchmark = True - try: train_net(net=net, epochs=args.epochs, - batch_size=args.batchsize, - lr=args.lr, + batch_size=args.batch_size, + learning_rate=args.lr, device=device, img_scale=args.scale, - val_percent=args.val / 100) + val_percent=args.val / 100, + amp=args.amp) except KeyboardInterrupt: torch.save(net.state_dict(), 'INTERRUPTED.pth') logging.info('Saved interrupt') - try: - sys.exit(0) - except SystemExit: - os._exit(0) + sys.exit(0) diff --git a/unet/unet_model.py b/unet/unet_model.py index 40291c5..efa7108 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -1,7 +1,5 @@ """ Full assembly of the parts to form the complete network """ -import torch.nn.functional as F - from .unet_parts import * diff --git a/unet/unet_parts.py b/unet/unet_parts.py index cf092f3..7f68f52 100644 --- a/unet/unet_parts.py +++ b/unet/unet_parts.py @@ -50,10 +50,9 @@ class Up(nn.Module): self.up = nn.Upsample(scale_factor=2, 
mode='bilinear', align_corners=True)
             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
         else:
-            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
+            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
             self.conv = DoubleConv(in_channels, out_channels)
 
-
     def forward(self, x1, x2):
         x1 = self.up(x1)
         # input is CHW
diff --git a/utils/data_vis.py b/utils.py
similarity index 65%
rename from utils/data_vis.py
rename to utils.py
index 95b9130..1d48738 100644
--- a/utils/data_vis.py
+++ b/utils.py
@@ -1,17 +1,17 @@
-import matplotlib.pyplot as plt
-
-
-def plot_img_and_mask(img, mask):
-    classes = mask.shape[2] if len(mask.shape) > 2 else 1
-    fig, ax = plt.subplots(1, classes + 1)
-    ax[0].set_title('Input image')
-    ax[0].imshow(img)
-    if classes > 1:
-        for i in range(classes):
-            ax[i+1].set_title(f'Output mask (class {i+1})')
-            ax[i+1].imshow(mask[:, :, i])
-    else:
-        ax[1].set_title(f'Output mask')
-        ax[1].imshow(mask)
-    plt.xticks([]), plt.yticks([])
-    plt.show()
+import matplotlib.pyplot as plt
+
+
+def plot_img_and_mask(img, mask):
+    classes = mask.shape[0] if len(mask.shape) > 2 else 1
+    fig, ax = plt.subplots(1, classes + 1)
+    ax[0].set_title('Input image')
+    ax[0].imshow(img)
+    if classes > 1:
+        for i in range(classes):
+            ax[i + 1].set_title(f'Output mask (class {i + 1})')
+            ax[i + 1].imshow(mask[i, :, :])
+    else:
+        ax[1].set_title(f'Output mask')
+        ax[1].imshow(mask)
+    plt.xticks([]), plt.yticks([])
+    plt.show()
diff --git a/utils/dataset.py b/utils/dataset.py
deleted file mode 100644
index 4878e03..0000000
--- a/utils/dataset.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from os.path import splitext
-from os import listdir
-import numpy as np
-from glob import glob
-import torch
-from torch.utils.data import Dataset
-import logging
-from PIL import Image
-
-
-class BasicDataset(Dataset):
-    def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
-        self.imgs_dir = imgs_dir
-        self.masks_dir = masks_dir
-        self.scale = scale
-        self.mask_suffix = mask_suffix
-        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
-
-        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
-                    if not file.startswith('.')]
-        logging.info(f'Creating dataset with {len(self.ids)} examples')
-
-    def __len__(self):
-        return len(self.ids)
-
-    @classmethod
-    def preprocess(cls, pil_img, scale):
-        w, h = pil_img.size
-        newW, newH = int(scale * w), int(scale * h)
-        assert newW > 0 and newH > 0, 'Scale is too small'
-        pil_img = pil_img.resize((newW, newH))
-
-        img_nd = np.array(pil_img)
-
-        if len(img_nd.shape) == 2:
-            img_nd = np.expand_dims(img_nd, axis=2)
-
-        # HWC to CHW
-        img_trans = img_nd.transpose((2, 0, 1))
-        if img_trans.max() > 1:
-            img_trans = img_trans / 255
-
-        return img_trans
-
-    def __getitem__(self, i):
-        idx = self.ids[i]
-        mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
-        img_file = glob(self.imgs_dir + idx + '.*')
-
-        assert len(mask_file) == 1, \
-            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
-        assert len(img_file) == 1, \
-            f'Either no image or multiple images found for the ID {idx}: {img_file}'
-        mask = Image.open(mask_file[0])
-        img = Image.open(img_file[0])
-
-        assert img.size == mask.size, \
-            f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
-
-        img = self.preprocess(img, self.scale)
-        mask = self.preprocess(mask, self.scale)
-
-        return {
-            'image': torch.from_numpy(img).type(torch.FloatTensor),
-            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
-        }
-
-
-class CarvanaDataset(BasicDataset):
-    def __init__(self, imgs_dir, masks_dir, scale=1):
-        super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')
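
The mask-layout convention shared by train.py, evaluate.py and predict.py in this change can be exercised standalone (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    masks = torch.randint(0, 2, (4, 32, 32))       # (N, H, W) integer class indices
    one_hot = F.one_hot(masks, num_classes=2)      # (N, H, W, C), channels last
    one_hot = one_hot.permute(0, 3, 1, 2).float()  # (N, C, H, W), channels first, as the losses expect
    assert one_hot.shape == (4, 2, 32, 32)

With the new flags, a mixed-precision training run might be launched like this (values are examples):

    python train.py --epochs 5 --batch-size 2 --learning-rate 1e-5 --scale 0.5 --validation 10 --amp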