diff --git a/eval.py b/eval.py
index e261532..f293bd1 100644
--- a/eval.py
+++ b/eval.py
@@ -10,20 +10,22 @@ def eval_net(net, loader, device, n_val):
     net.eval()
     tot = 0
 
-    for i, b in tqdm(enumerate(loader), desc='Validation round', unit='img'):
-        imgs = b['image']
-        true_masks = b['mask']
+    with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
+        for batch in loader:
+            imgs = batch['image']
+            true_masks = batch['mask']
 
-        imgs = imgs.to(device=device, dtype=torch.float32)
-        true_masks = true_masks.to(device=device, dtype=torch.float32)
+            imgs = imgs.to(device=device, dtype=torch.float32)
+            true_masks = true_masks.to(device=device, dtype=torch.float32)
 
-        mask_pred = net(imgs)
+            mask_pred = net(imgs)
 
-        for true_mask in true_masks:
-            mask_pred = (mask_pred > 0.5).float()
-            if net.n_classes > 1:
-                tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
-            else:
-                tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+            for true_mask in true_masks:
+                mask_pred = (mask_pred > 0.5).float()
+                if net.n_classes > 1:
+                    tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
+                else:
+                    tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+            pbar.update(imgs.shape[0])
 
     return tot / n_val
diff --git a/predict.py b/predict.py
index dcb879e..16b7b1c 100755
--- a/predict.py
+++ b/predict.py
@@ -9,8 +9,10 @@ from torchvision import transforms
 import torch.nn.functional as F
 
 from unet import UNet
-from utils import plot_img_and_mask
+from utils.data_vis import plot_img_and_mask
 from utils.dataset import BasicDataset
+from utils.crf import dense_crf
+
 
 def predict_img(net,
                 full_img,
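Note on eval.py: dice_coeff is imported from elsewhere in the repo and is not part of this diff. For reference, the Dice score of two binary masks A and B is 2*|A ∩ B| / (|A| + |B|); a minimal sketch of such a function (the name, epsilon, and reduction below are illustrative assumptions, not the repository's actual code):

    import torch

    def dice_coeff_sketch(pred, target, eps=1e-6):
        # pred and target: float tensors of the same shape with values in {0, 1}
        inter = (pred * target).sum()
        # eps (assumed) guards against division by zero on empty masks
        return (2 * inter + eps) / (pred.sum() + target.sum() + eps)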
diff --git a/submit.py b/submit.py
index 49acd88..65a0df5 100644
--- a/submit.py
+++ b/submit.py
@@ -4,10 +4,23 @@ import os
 
 import torch
 from PIL import Image
+import numpy as np
 
 from predict import predict_img
 from unet import UNet
-from utils import rle_encode
+
+# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
+def rle_encode(mask_image):
+    pixels = mask_image.flatten()
+    # We avoid issues with '1' at the start or end (at the corners of
+    # the original image) by setting those pixels to '0' explicitly.
+    # We do not expect these to be non-zero for an accurate mask,
+    # so this should not harm the score.
+    pixels[0] = 0
+    pixels[-1] = 0
+    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
+    runs[1::2] = runs[1::2] - runs[:-1:2]
+    return runs
 
 
 def submit(net, gpu=False):
diff --git a/train.py b/train.py
index ebc80f6..55f3a82 100644
--- a/train.py
+++ b/train.py
@@ -11,8 +11,8 @@ from tqdm import tqdm
 
 from eval import eval_net
 from unet import UNet
-from utils import get_ids, split_train_val, get_imgs_and_masks, batch
+from torch.utils.tensorboard import SummaryWriter
 from utils.dataset import BasicDataset
 from torch.utils.data import DataLoader, random_split
 
@@ -26,7 +26,7 @@ def train_net(net,
               epochs=5,
               batch_size=1,
               lr=0.1,
-              val_percent=0.15,
+              val_percent=0.1,
               save_cp=True,
               img_scale=0.5):
 
@@ -34,8 +34,11 @@ def train_net(net,
     n_val = int(len(dataset) * val_percent)
     n_train = len(dataset) - n_val
     train, val = random_split(dataset, [n_train, n_val])
-    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
-    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4)
+    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
+    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
+
+    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
+    global_step = 0
 
     logging.info(f'''Starting training:
         Epochs:          {epochs}
@@ -48,7 +51,7 @@ def train_net(net,
         Images scaling:  {img_scale}
     ''')
 
-    optimizer = optim.Adam(net.parameters(), lr=lr)
+    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
     if net.n_classes > 1:
         criterion = nn.CrossEntropyLoss()
     else:
@@ -78,6 +81,7 @@ def train_net(net,
                 masks_pred = net(imgs)
                 loss = criterion(masks_pred, true_masks)
                 epoch_loss += loss.item()
+                writer.add_scalar('Loss/train', loss.item(), global_step)
 
                 pbar.set_postfix(**{'loss (batch)': loss.item()})
 
@@ -85,7 +89,22 @@ def train_net(net,
                 loss.backward()
                 optimizer.step()
 
-                pbar.update(batch_size)
+                pbar.update(imgs.shape[0])
+                global_step += 1
+                if global_step % (len(dataset) // (10 * batch_size)) == 0:
+                    val_score = eval_net(net, val_loader, device, n_val)
+                    if net.n_classes > 1:
+                        logging.info('Validation cross entropy: {}'.format(val_score))
+                        writer.add_scalar('Loss/test', val_score, global_step)
+
+                    else:
+                        logging.info('Validation Dice Coeff: {}'.format(val_score))
+                        writer.add_scalar('Dice/test', val_score, global_step)
+
+                    writer.add_images('images', imgs, global_step)
+                    if net.n_classes == 1:
+                        writer.add_images('masks/true', true_masks, global_step)
+                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
 
         if save_cp:
             try:
@@ -97,12 +116,7 @@ def train_net(net,
                        dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
             logging.info(f'Checkpoint {epoch + 1} saved !')
 
-    val_score = eval_net(net, val_loader, device, n_val)
-    if net.n_classes > 1:
-        logging.info('Validation cross entropy: {}'.format(val_score))
-
-    else:
-        logging.info('Validation Dice Coeff: {}'.format(val_score))
+    writer.close()
 
 
 def get_args():
@@ -118,7 +132,7 @@ def get_args():
                         help='Load model from a .pth file')
     parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                         help='Downscaling factor of the images')
-    parser.add_argument('-v', '--validation', dest='val', type=float, default=15.0,
+    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                         help='Percent of the data that is used as validation (0-100)')
 
     return parser.parse_args()
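Note on train.py: the SummaryWriter calls added above use the standard torch.utils.tensorboard API. A minimal, self-contained sketch of the same logging pattern (the loss values below are hypothetical):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(comment='LR_0.1_BS_1_SCALE_0.5')  # comment is appended to the run directory name
    for step, loss in enumerate([0.9, 0.7, 0.5]):  # hypothetical loss values
        writer.add_scalar('Loss/train', loss, step)
    writer.close()

Runs are written under ./runs/ by default and can be viewed with: tensorboard --logdir=runs. The global_step % (len(dataset) // (10 * batch_size)) == 0 check above schedules the new in-loop validation roughly ten times per epoch.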
diff --git a/utils/__init__.py b/utils/__init__.py
deleted file mode 100644
index 54e2c6f..0000000
--- a/utils/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .crf import *
-from .load import *
-from .utils import *
-from .data_vis import *
diff --git a/utils/dataset.py b/utils/dataset.py
index 69d922a..57c069a 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -25,6 +25,7 @@ class BasicDataset(Dataset):
     def preprocess(self, pil_img):
         w, h = pil_img.size
         newW, newH = int(self.scale * w), int(self.scale * h)
+        assert newW > 0 and newH > 0, 'Scale is too small'
         pil_img = pil_img.resize((newW, newH))
 
         img_nd = np.array(pil_img)
diff --git a/utils/utils.py b/utils/utils.py
deleted file mode 100644
index eb2f855..0000000
--- a/utils/utils.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import numpy as np
-
-# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
-def rle_encode(mask_image):
-    pixels = mask_image.flatten()
-    # We avoid issues with '1' at the start or end (at the corners of
-    # the original image) by setting those pixels to '0' explicitly.
-    # We do not expect these to be non-zero for an accurate mask,
-    # so this should not harm the score.
-    pixels[0] = 0
-    pixels[-1] = 0
-    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
-    runs[1::2] = runs[1::2] - runs[:-1:2]
-    return runs
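Note on submit.py: rle_encode is the same implementation that was removed from utils/utils.py, relocated rather than rewritten. A quick worked example of its output format (run starts are 1-indexed pixel positions, each followed by the run length):

    import numpy as np

    # rle_encode as defined in submit.py above
    mask = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0])
    print(rle_encode(mask))  # [2 3 7 2]: a run of 3 ones starting at pixel 2, then a run of 2 starting at pixel 7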