From 982dfe99d78ccd5babba84f80c21ed4fae579463 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Mon, 4 Jul 2022 21:40:38 +0200 Subject: [PATCH 01/15] feat(WIP): switching to pytorch lightning Former-commit-id: 0038dbca182717af8fc4bd846fd5be0e9fa70a9a [formerly eb5eb0717f8511bf49de8393bbdc66e727b930ff] Former-commit-id: 540304228b146fe8e086bc4ccb770a13f84cbbcb --- src/train.py | 275 ++++++---------------------------------------- src/unet/model.py | 170 +++++++++++++++++++++++++++- 2 files changed, 202 insertions(+), 243 deletions(-) diff --git a/src/train.py b/src/train.py index 9ddac8e..e8f3594 100644 --- a/src/train.py +++ b/src/train.py @@ -1,16 +1,16 @@ import logging import albumentations as A +import pytorch_lightning as pl import torch import yaml from albumentations.pytorch import ToTensorV2 +from pytorch_lightning.loggers import WandbLogger from torch.utils.data import DataLoader -from tqdm import tqdm import wandb from src.utils.dataset import SphereDataset from unet import UNet -from utils.dice import dice_coeff from utils.paste import RandomPaste class_labels = { @@ -22,7 +22,7 @@ if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") # setup wandb - wandb.init( + logger = WandbLogger( project="U-Net", config=dict( DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/", @@ -36,7 +36,7 @@ if __name__ == "__main__": AMP=True, PIN_MEMORY=True, BENCHMARK=True, - DEVICE="cuda", + DEVICE="gpu", WORKERS=8, EPOCHS=5, BATCH_SIZE=16, @@ -51,18 +51,17 @@ if __name__ == "__main__": ), ) - # create device - device = torch.device(wandb.config.DEVICE) - - # enable cudnn benchmarking - torch.backends.cudnn.benchmark = wandb.config.BENCHMARK + # seed random generators + pl.seed_everything(69420, workers=True) # 0. Create network net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) + + # log the number of parameters of the model wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) - # transfer network to device - net.to(device=device) + # log gradients and weights regularly + logger.watch(net, log="all") # 1. Create transforms tf_train = A.Compose( @@ -121,244 +120,38 @@ if __name__ == "__main__": pin_memory=wandb.config.PIN_MEMORY, ) - # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp - optimizer = torch.optim.RMSprop( - net.parameters(), - lr=wandb.config.LEARNING_RATE, - weight_decay=wandb.config.WEIGHT_DECAY, - momentum=wandb.config.MOMENTUM, + # 4. 
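[Editor's note] The optimizer block removed just above does not disappear in the Lightning port: it moves into the module's `configure_optimizers` hook, which the trainer calls for you (patch 01's model.py further below does exactly this). A minimal sketch of the idiom, reusing the hyperparameters from the config above:

    import pytorch_lightning as pl
    import torch

    class LitUNet(pl.LightningModule):
        def configure_optimizers(self):
            # same settings the removed block passed explicitly
            return torch.optim.RMSprop(
                self.parameters(),
                lr=1e-4,            # LEARNING_RATE
                weight_decay=1e-8,  # WEIGHT_DECAY
                momentum=0.9,       # MOMENTUM
            )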
Create the trainer + trainer = pl.Trainer( + max_epochs=wandb.config.EPOCHS, + accelerator="gpu", + precision=16, + auto_scale_batch_size="binsearch", + benchmark=wandb.config.BENCHMARK, + val_check_interval=100, ) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) - grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) - criterion = torch.nn.BCEWithLogitsLoss() - - # save model.onxx - dummy_input = torch.randn( - 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True - ).to(device) - torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx") - artifact = wandb.Artifact("onnx", type="model") - artifact.add_file("checkpoints/model-0.onnx") - wandb.run.log_artifact(artifact) - - # log gradients and weights four time per epoch - wandb.watch(net, criterion, log_freq=100) # print the config logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") - # wandb init log - wandb.log( - { - "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], - }, - commit=False, - ) + # # wandb init log + # wandb.log( + # { + # "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], + # }, + # commit=False, + # ) try: - for epoch in range(1, wandb.config.EPOCHS + 1): - with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar: - - # Training round - for step, (images, true_masks) in enumerate(train_loader): - assert images.shape[1] == net.n_channels, ( - f"Network has been defined with {net.n_channels} input channels, " - f"but loaded images have {images.shape[1]} channels. Please check that " - "the images are loaded correctly." - ) - - # transfer images to device - images = images.to(device=device) - true_masks = true_masks.unsqueeze(1).to(device=device) - - # forward - with torch.cuda.amp.autocast(enabled=wandb.config.AMP): - pred_masks = net(images) - train_loss = criterion(pred_masks, true_masks) - - # backward - optimizer.zero_grad(set_to_none=True) - grad_scaler.scale(train_loss).backward() - grad_scaler.step(optimizer) - grad_scaler.update() - - # compute metrics - pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() - accuracy = (true_masks == pred_masks_bin).float().mean() - dice = dice_coeff(pred_masks_bin, true_masks) - mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) - - # update tqdm progress bar - pbar.update(images.shape[0]) - pbar.set_postfix(**{"loss": train_loss.item()}) - - # log metrics - wandb.log( - { - "epoch": epoch - 1 + step / len(train_loader), - "train/accuracy": accuracy, - "train/bce": train_loss, - "train/dice": dice, - "train/mae": mae, - } - ) - - if step and (step % 250 == 0 or step == len(train_loader)): - # Evaluation round - net.eval() - accuracy = 0 - val_loss = 0 - dice = 0 - mae = 0 - with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2: - for images, masks_true in val_loader: - - # transfer images to device - images = images.to(device=device) - masks_true = masks_true.unsqueeze(1).to(device=device) - - # forward - with torch.inference_mode(): - masks_pred = net(images) - - # compute metrics - val_loss += criterion(masks_pred, masks_true) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) - - # update progress bar - pbar2.update(images.shape[0]) - - accuracy /= 
len(val_loader) - val_loss /= len(val_loader) - dice /= len(val_loader) - mae /= len(val_loader) - - # save the last validation batch to table - table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) - for i, (img, mask, pred, pred_bin) in enumerate( - zip( - images.cpu(), - masks_true.cpu(), - masks_pred.cpu(), - masks_pred_bin.cpu().squeeze(1).int().numpy(), - ) - ): - table.add_data( - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, - }, - ), - ) - - # log validation metrics - wandb.log( - { - "val/predictions": table, - "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], - "val/accuracy": accuracy, - "val/bce": val_loss, - "val/dice": dice, - "val/mae": mae, - }, - commit=False, - ) - - # update hyperparameters - net.train() - scheduler.step(dice) - - # export model to onnx format when validation ends - dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) - torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx") - artifact = wandb.Artifact("onnx", type="model") - artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx") - wandb.run.log_artifact(artifact) - - # testing round - net.eval() - accuracy = 0 - val_loss = 0 - dice = 0 - mae = 0 - with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3: - for images, masks_true in test_loader: - - # transfer images to device - images = images.to(device=device) - masks_true = masks_true.unsqueeze(1).to(device=device) - - # forward - with torch.inference_mode(): - masks_pred = net(images) - - # compute metrics - val_loss += criterion(masks_pred, masks_true) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) - - # update progress bar - pbar3.update(images.shape[0]) - - accuracy /= len(test_loader) - val_loss /= len(test_loader) - dice /= len(test_loader) - mae /= len(test_loader) - - # save the last validation batch to table - table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) - for i, (img, mask, pred, pred_bin) in enumerate( - zip( - images.cpu(), - masks_true.cpu(), - masks_pred.cpu(), - masks_pred_bin.cpu().squeeze(1).int().numpy(), - ) - ): - table.add_data( - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, - }, - ), - ) - - # log validation metrics - wandb.log( - { - "test/predictions": table, - "test/accuracy": accuracy, - "test/bce": val_loss, - "test/dice": dice, - "test/mae": mae, - }, - commit=False, - ) - - # stop wandb - wandb.run.finish() - + trainer.fit( + model=net, + train_dataloaders=train_loader, + val_dataloaders=val_loader, + test_dataloaders=test_loader, + accelerator=wandb.config.DEVICE, + ) except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") raise -# sapin de noel + # stop wandb + wandb.run.finish() diff --git a/src/unet/model.py b/src/unet/model.py index 08d2807..378b407 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -1,9 +1,21 @@ """ Full assembly of the parts to form the complete network """ +from xmlrpc.server import list_public_methods + +import numpy as np +import pytorch_lightning as pl + +import wandb +from utils.dice import dice_coeff + from 
.blocks import * +class_labels = { + 1: "sphere", +} -class UNet(nn.Module): + +class UNet(pl.LightningModule): def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): super(UNet, self).__init__() self.n_channels = n_channels @@ -26,7 +38,6 @@ class UNet(nn.Module): self.outc = OutConv(features[0], n_classes) def forward(self, x): - skips = [] x = self.inc(x) @@ -41,3 +52,158 @@ class UNet(nn.Module): x = self.outc(x) return x + + @staticmethod + def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key): + table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) + + for i, (img, mask, pred, pred_bin) in enumerate( + zip( + images.cpu(), + masks_true.cpu(), + masks_pred.cpu(), + masks_pred_bin.cpu().squeeze(1).int().numpy(), + ) + ): + table.add_data( + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + ) + + wandb.log( + { + log_key: table, + } + ) + + def training_step(self, batch, batch_idx): + # unpacking + images, masks_true = batch + masks_true = masks_true.unsqueeze(1) + masks_pred = self(images) + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + + # compute metrics + loss = F.cross_entropy(masks_pred, masks_true) + mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + accuracy = (masks_true == masks_pred_bin).float().mean() + dice = dice_coeff(masks_pred_bin, masks_true) + + wandb.log( + { + "train/accuracy": accuracy, + "train/bce": loss, + "train/dice": dice, + "train/mae": mae, + } + ) + + return loss, dice, accuracy, mae + + def validation_step(self, batch, batch_idx): + # unpacking + images, masks_true = batch + masks_true = masks_true.unsqueeze(1) + masks_pred = self(images) + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + + # compute metrics + loss = F.cross_entropy(masks_pred, masks_true) + mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + accuracy = (masks_true == masks_pred_bin).float().mean() + dice = dice_coeff(masks_pred_bin, masks_true) + + if batch_idx == 0: + self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions") + + return loss, dice, accuracy, mae + + def validation_step_end(self, validation_outputs): + # unpacking + loss, dice, accuracy, mae = validation_outputs + optimizer = self.optimizers[0] + learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + + wandb.log( + { + "train/learning_rate": learning_rate, + "val/accuracy": accuracy, + "val/bce": loss, + "val/dice": dice, + "val/mae": mae, + } + ) + + # export model to onnx + dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) + torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") + artifact = wandb.Artifact("onnx", type="model") + artifact.add_file(f"checkpoints/model.onnx") + wandb.run.log_artifact(artifact) + + def test_step(self, batch, batch_idx): + # unpacking + images, masks_true = batch + masks_true = masks_true.unsqueeze(1) + masks_pred = self(images) + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + + # compute metrics + loss = F.cross_entropy(masks_pred, masks_true) + mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + accuracy = (masks_true == masks_pred_bin).float().mean() + dice = dice_coeff(masks_pred_bin, masks_true) + + if batch_idx == 0: + self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") + + return loss, dice, accuracy, mae + + def test_step_end(self, 
test_outputs): + # unpacking + list_loss, list_dice, list_accuracy, list_mae = test_outputs + + # averaging + loss = np.mean(list_loss) + dice = np.mean(list_dice) + accuracy = np.mean(list_accuracy) + mae = np.mean(list_mae) + + # get learning rate + optimizer = self.optimizers[0] + learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + + wandb.log( + { + "train/learning_rate": learning_rate, + "val/accuracy": accuracy, + "val/bce": loss, + "val/dice": dice, + "val/mae": mae, + } + ) + + def configure_optimizers(self): + optimizer = torch.optim.RMSprop( + self.parameters(), + lr=wandb.config.LEARNING_RATE, + weight_decay=wandb.config.WEIGHT_DECAY, + momentum=wandb.config.MOMENTUM, + ) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + "max", + patience=2, + ) + + return optimizer, scheduler From 0fb1d4fb7a52fd95d5f68d1eccf5f8324f4fe08c Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 12:06:12 +0200 Subject: [PATCH 02/15] feat(WIP): broken onnx prediction Former-commit-id: cde7623ec486cf79a710949085aadd92d8a33a3e [formerly db0f1d0b9ea536c741f23a3b683e19a9335bcd35] Former-commit-id: 7332ccb0f74c58c3a284a4568fb8f80a6d416cf4 --- .vscode/launch.json | 6 ++-- src/predict.py | 53 +++++++++++---------------- src/train.py | 40 ++++++++++++--------- src/utils/dice.py | 88 ++++++++------------------------------------- src/utils/paste.py | 2 +- 5 files changed, 64 insertions(+), 125 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 178c089..4992c79 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -14,9 +14,11 @@ "--input", "images/SM.png", "--output", - "output.png", + "output_onnx.png", + "--model", + "good.onnx", ], "justMyCode": true } ] -} \ No newline at end of file +} diff --git a/src/predict.py b/src/predict.py index a69f7bc..0df1bab 100755 --- a/src/predict.py +++ b/src/predict.py @@ -2,13 +2,12 @@ import argparse import logging import albumentations as A +import cv2 import numpy as np import torch from albumentations.pytorch import ToTensorV2 from PIL import Image -from unet import UNet - def get_args(): parser = argparse.ArgumentParser( @@ -38,47 +37,35 @@ def get_args(): return parser.parse_args() +def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + if __name__ == "__main__": args = get_args() logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - net = UNet(n_channels=3, n_classes=1) - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logging.info(f"Using device {device}") - - logging.info("Transfering model to device") - net.to(device=device) - - logging.info(f"Loading model {args.model}") - net.load_state_dict(torch.load(args.model, map_location=device)) + net = cv2.dnn.readNetFromONNX(args.model) + logging.info("onnx model loaded") logging.info(f"Loading image {args.input}") - img = Image.open(args.input).convert("RGB") + input_img = cv2.imread(args.input, cv2.IMREAD_COLOR) + input_img = input_img.astype(np.float32) + # input_img = cv2.resize(input_img, (512, 512)) - logging.info(f"Preprocessing image {args.input}") - tf = A.Compose( - [ - A.ToFloat(max_value=255), - ToTensorV2(), - ], + logging.info("converting to blob") + input_blob = cv2.dnn.blobFromImage( + image=input_img, + scalefactor=1 / 255, ) - aug = tf(image=np.asarray(img)) - img = aug["image"] - logging.info(f"Predicting image {args.input}") - img = img.unsqueeze(0).to(device=device, dtype=torch.float32) - - net.eval() - with torch.inference_mode(): - mask = net(img) - mask = 
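[Editor's note] A Lightning detail worth flagging about `configure_optimizers` in the model.py hunk above: returning a bare `(optimizer, scheduler)` tuple is not one of the supported return shapes, and `ReduceLROnPlateau` additionally has to be told which logged metric to monitor. The PL 1.x contract looks roughly like this (monitor name illustrative), which may be why a later patch in this series drops the scheduler entirely:

    def configure_optimizers(self):
        optimizer = torch.optim.RMSprop(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val/dice"},
        }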
torch.sigmoid(mask)[0] - mask = mask.cpu() - mask = mask.squeeze() - mask = mask > 0.5 - mask = np.asarray(mask) + net.setInput(input_blob) + mask = net.forward() + mask = sigmoid(mask) + mask = mask > 0.5 + mask = mask.astype(np.float32) logging.info(f"Saving prediction to {args.output}") - mask = Image.fromarray(mask) + mask = Image.fromarray(mask, "L") mask.save(args.output) diff --git a/src/train.py b/src/train.py index 9ddac8e..2963c3b 100644 --- a/src/train.py +++ b/src/train.py @@ -5,12 +5,13 @@ import torch import yaml from albumentations.pytorch import ToTensorV2 from torch.utils.data import DataLoader +from torchmetrics import Dice from tqdm import tqdm import wandb from src.utils.dataset import SphereDataset from unet import UNet -from utils.dice import dice_coeff +from utils.dice import DiceLoss from utils.paste import RandomPaste class_labels = { @@ -37,8 +38,8 @@ if __name__ == "__main__": PIN_MEMORY=True, BENCHMARK=True, DEVICE="cuda", - WORKERS=8, - EPOCHS=5, + WORKERS=7, + EPOCHS=1001, BATCH_SIZE=16, LEARNING_RATE=1e-4, WEIGHT_DECAY=1e-8, @@ -92,9 +93,13 @@ if __name__ == "__main__": ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) # 2.5. Create subset, if uncommented - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) - ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000))) - ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) + # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) + # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) + # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) + + ds_train = torch.utils.data.Subset(ds_train, [0]) + ds_valid = torch.utils.data.Subset(ds_valid, [0]) + ds_test = torch.utils.data.Subset(ds_test, [0]) # 3. 
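[Editor's note] Pinning every split to index `[0]` like this is the classic overfit-one-sample sanity check: if the network cannot drive the loss to zero on a single image, the bug is in the pipeline rather than in model capacity. The idiom in isolation (dataset name hypothetical):

    from torch.utils.data import Subset

    # keep exactly one sample; training should reach near-zero loss on it
    tiny_train = Subset(full_dataset, [0])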
Create data loaders train_loader = DataLoader( @@ -131,18 +136,19 @@ if __name__ == "__main__": scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) criterion = torch.nn.BCEWithLogitsLoss() + dice_loss = DiceLoss() # save model.onxx dummy_input = torch.randn( 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True ).to(device) - torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx") + torch.onnx.export(net, dummy_input, "checkpoints/model.onnx") artifact = wandb.Artifact("onnx", type="model") artifact.add_file("checkpoints/model-0.onnx") wandb.run.log_artifact(artifact) # log gradients and weights four time per epoch - wandb.watch(net, criterion, log_freq=100) + wandb.watch(net, log_freq=100) # print the config logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") @@ -176,6 +182,8 @@ if __name__ == "__main__": pred_masks = net(images) train_loss = criterion(pred_masks, true_masks) + # compute loss + # backward optimizer.zero_grad(set_to_none=True) grad_scaler.scale(train_loss).backward() @@ -185,7 +193,7 @@ if __name__ == "__main__": # compute metrics pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() accuracy = (true_masks == pred_masks_bin).float().mean() - dice = dice_coeff(pred_masks_bin, true_masks) + dice = dice_loss.coeff(pred_masks, true_masks) mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) # update tqdm progress bar @@ -197,13 +205,13 @@ if __name__ == "__main__": { "epoch": epoch - 1 + step / len(train_loader), "train/accuracy": accuracy, - "train/bce": train_loss, + "train/loss": train_loss, "train/dice": dice, "train/mae": mae, } ) - if step and (step % 250 == 0 or step == len(train_loader)): + if step and (step % 100 == 0 or step == len(train_loader)): # Evaluation round net.eval() accuracy = 0 @@ -223,10 +231,10 @@ if __name__ == "__main__": # compute metrics val_loss += criterion(masks_pred, masks_true) + dice += dice_loss.coeff(pred_masks, true_masks) masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) # update progress bar pbar2.update(images.shape[0]) @@ -267,7 +275,7 @@ if __name__ == "__main__": "val/predictions": table, "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], "val/accuracy": accuracy, - "val/bce": val_loss, + "val/loss": val_loss, "val/dice": dice, "val/mae": mae, }, @@ -276,7 +284,7 @@ if __name__ == "__main__": # update hyperparameters net.train() - scheduler.step(dice) + scheduler.step(train_loss) # export model to onnx format when validation ends dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) @@ -304,10 +312,10 @@ if __name__ == "__main__": # compute metrics val_loss += criterion(masks_pred, masks_true) + dice += dice_loss.coeff(pred_masks, true_masks) masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy += (masks_true == masks_pred_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) # update progress bar pbar3.update(images.shape[0]) @@ -347,7 +355,7 @@ if __name__ == "__main__": { "test/predictions": table, "test/accuracy": accuracy, - "test/bce": val_loss, + "test/loss": val_loss, "test/dice": dice, "test/mae": mae, }, diff --git a/src/utils/dice.py 
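[Editor's note] One slip to note in the train.py hunks above: inside the evaluation and testing rounds the added lines read `dice += dice_loss.coeff(pred_masks, true_masks)`, but those names belong to the training round. The loop-local tensors are `masks_pred` and `masks_true`, so the metric as committed is computed on stale training tensors; the intended line is presumably:

    dice += dice_loss.coeff(masks_pred, masks_true)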
b/src/utils/dice.py index a29f794..c4b201c 100644 --- a/src/utils/dice.py +++ b/src/utils/dice.py @@ -1,80 +1,22 @@ import torch -from torch import Tensor +import torch.nn as nn -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all batches, or for a single mask. +class DiceLoss(nn.Module): + def __init__(self, weight=None, size_average=True): + super(DiceLoss, self).__init__() - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. + @staticmethod + def coeff(inputs, targets, smooth=1): + # comment out if your model contains a sigmoid or equivalent activation layer + inputs = torch.sigmoid(inputs) - Raises: - ValueError: _description_ + # flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) - Returns: - float: _description_ - """ - assert input.size() == target.size() + intersection = (inputs * targets).sum() + return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) - if input.dim() == 2 and reduce_batch_first: - raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})") - - if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.reshape(-1), target.reshape(-1)) - sets_sum = torch.sum(input) + torch.sum(target) - - if sets_sum.item() == 0: - sets_sum = 2 * inter - - return (2 * inter + epsilon) / (sets_sum + epsilon) - else: - # compute and average metric for each batch element - dice = 0 - - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) - - return dice / input.shape[0] - - -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all classes. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - dice = 0 - - for channel in range(input.shape[1]): - dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) - - return dice / input.shape[1] - - -def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float: - """Dice loss (objective to minimize) between 0 and 1. - - Args: - input (Tensor): _description_ - target (Tensor): _description_ - multiclass (bool, optional): _description_. Defaults to False. 
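[Editor's note] The new `DiceLoss.coeff` above computes a smoothed Dice coefficient on flattened tensors, (2 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth); the smoothing term keeps the ratio defined when both masks are empty. A quick usage sketch (shapes illustrative):

    import torch

    logits = torch.randn(4, 1, 512, 512)                    # raw network output
    target = torch.randint(0, 2, (4, 1, 512, 512)).float()  # binary ground truth

    coeff = DiceLoss.coeff(logits, target)  # sigmoid is applied internally
    loss = 1 - coeff                        # what DiceLoss.forward returns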
- - Returns: - float: _description_ - """ - assert input.size() == target.size() - - fn = multiclass_dice_coeff if multiclass else dice_coeff - - return 1 - fn(input, target, reduce_batch_first=True) + def forward(self, inputs, targets, smooth=1): + return 1 - self.coeff(inputs, targets, smooth) diff --git a/src/utils/paste.py b/src/utils/paste.py index 486a8ec..90ef0a8 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -24,7 +24,7 @@ class RandomPaste(A.DualTransform): nb, path_paste_img_dir, path_paste_mask_dir, - scale_range=(0.1, 0.2), + scale_range=(0.05, 0.25), always_apply=True, p=1.0, ): From 36b044c719effa98e6cacd325193d1fd53a5fd52 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 12:17:32 +0200 Subject: [PATCH 03/15] feat: working prediction but only for 512x512 Former-commit-id: c9d88ad18de91409fc1be1f1abe59d6e75ff2235 [formerly 8bf12bad1c3e8424aa26c7bd9a441facc670b059] Former-commit-id: 820e327f8a79bad36d1f15944012e77ba1ecd560 --- .vscode/launch.json | 2 +- poetry.lock | 83 ++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ src/predict.py | 48 +++++++++++++++----------- 4 files changed, 114 insertions(+), 21 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 4992c79..a0ae3f2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "console": "integratedTerminal", "args": [ "--input", - "images/SM.png", + "images/test.png", "--output", "output_onnx.png", "--model", diff --git a/poetry.lock b/poetry.lock index ec4dac1..dea7656 100644 --- a/poetry.lock +++ b/poetry.lock @@ -240,6 +240,14 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "flatbuffers" +version = "2.0" +description = "The FlatBuffers serialization format for Python" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "fonttools" version = "4.33.3" @@ -673,6 +681,35 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "onnx" +version = "1.12.0" +description = "Open Neural Network Exchange" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = ">=1.16.6" +protobuf = ">=3.12.2,<=3.20.1" +typing-extensions = ">=3.6.2.1" + +[package.extras] +lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"] + +[[package]] +name = "onnxruntime" +version = "1.11.1" +description = "ONNX Runtime is a runtime accelerator for Machine Learning models" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +flatbuffers = "*" +numpy = ">=1.21.6" +protobuf = "*" + [[package]] name = "opencv-python-headless" version = "4.6.0.66" @@ -1446,7 +1483,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.11" -content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5" +content-hash = "c1855a97cbe537d31526a76a49fde822e98acc89713f5f902639327b688c079a" [metadata.files] absl-py = [ @@ -1692,6 +1729,10 @@ executing = [ {file = "executing-0.8.3-py2.py3-none-any.whl", hash = "sha256:d1eef132db1b83649a3905ca6dd8897f71ac6f8cac79a7e58a1a09cf137546c9"}, {file = "executing-0.8.3.tar.gz", hash = "sha256:c6554e21c6b060590a6d3be4b82fb78f8f0194d809de5ea7df1c093763311501"}, ] +flatbuffers = [ + {file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = 
"sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"}, + {file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"}, +] fonttools = [ {file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"}, {file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"}, @@ -2056,6 +2097,46 @@ oauthlib = [ {file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"}, {file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"}, ] +onnx = [ + {file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"}, + {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"}, + {file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"}, + {file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"}, + {file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"}, + {file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"}, + {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"}, + {file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"}, + {file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"}, + {file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"}, + {file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"}, + {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"}, + {file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"}, + {file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"}, + {file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"}, + {file = "onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"}, + {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"}, + {file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"}, + {file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = 
"sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"}, + {file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"}, + {file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"}, +] +onnxruntime = [ + {file = "onnxruntime-1.11.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:88b94a900754ef189c2b06f2046f2de8008753e0e8a3e562b2fb03298026b4b4"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:958974be7808b46815533c74e8849a2d73e73d656df8369a114ce3359f77760b"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f83e7d52932b68b08cfba4920816efc7c3177036a90116137b11888e1f2490"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-win32.whl", hash = "sha256:3106bfcd1532afcaa26fc47931f7a8770dc710263647e8fbb5f75fa5a8fc70f9"}, + {file = "onnxruntime-1.11.1-cp37-cp37m-win_amd64.whl", hash = "sha256:80775f4f64850b6774dbaa955888a89dc719cf654f1995ed5418e78c0139b5f4"}, + {file = "onnxruntime-1.11.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:b24a3cd1e6d7fe7c4c5be2996ba02ebf8beed6347b2fd3ac869d1c685a2e0264"}, + {file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73c4df0a446fe49d59629746d2163fa39b732b6afb3b5d00f8c9ec91a040e5c4"}, + {file = "onnxruntime-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149bf850c4e320e33894cab3e350a945ab17690cf54ffa00ef965273112ef614"}, + {file = "onnxruntime-1.11.1-cp38-cp38-win32.whl", hash = "sha256:9e202d7323b5728cdc3c0ee3bbc35f10cd56c7120c9626887c4ebe5d8503b488"}, + {file = "onnxruntime-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:00632fc2ee3cf86349f5b00f5385a62fe5720ef14b471c919cc2c94faeb446d0"}, + {file = "onnxruntime-1.11.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:792985ddf3d3c46efa24bcfef970e7ccd4421d46173a96ca3974dab709598591"}, + {file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ce95870b0bc7cbef5383b3c3062c6d9784af71f266192bc887928d7b927a46"}, + {file = "onnxruntime-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53121b024f68d6b16bb93bc3fb73ba05b6f55647d12054a8efae7f48ed761add"}, + {file = "onnxruntime-1.11.1-cp39-cp39-win32.whl", hash = "sha256:b90124277454c50c5b2073bb9e1368b2a5672f30c2c8f3fe01393967dcd6dce2"}, + {file = "onnxruntime-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:b1ffefc961fc607e5929fef92f3bc8bc48bd3a074b2a6448887be23eb313f75a"}, +] opencv-python-headless = [ {file = "opencv-python-headless-4.6.0.66.tar.gz", hash = "sha256:d5291d7e10aa2c19cab6fd86f0d61af8617290ecd2d7ffcb051e446868d04cc5"}, {file = "opencv_python_headless-4.6.0.66-cp36-abi3-macosx_10_15_x86_64.whl", hash = "sha256:21e70f8b0c04098cdf466d27184fe6c3820aaef944a22548db95099959c95889"}, diff --git a/pyproject.toml b/pyproject.toml index 426fef6..32eaa53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ torch = "^1.11.0" torchvision = "^0.12.0" tqdm = "^4.64.0" wandb = "^0.12.19" +onnx = "^1.12.0" +onnxruntime = "^1.11.1" [tool.poetry.dev-dependencies] black = "^22.3.0" diff --git a/src/predict.py b/src/predict.py index 0df1bab..f1be604 100755 --- a/src/predict.py +++ b/src/predict.py @@ -2,8 +2,9 @@ import argparse import logging import albumentations as A -import cv2 import numpy as np +import onnx +import onnxruntime import torch 
from albumentations.pytorch import ToTensorV2 from PIL import Image @@ -46,26 +47,35 @@ if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - net = cv2.dnn.readNetFromONNX(args.model) - logging.info("onnx model loaded") + onnx_model = onnx.load(args.model) + onnx.checker.check_model(onnx_model) - logging.info(f"Loading image {args.input}") - input_img = cv2.imread(args.input, cv2.IMREAD_COLOR) - input_img = input_img.astype(np.float32) - # input_img = cv2.resize(input_img, (512, 512)) + ort_session = onnxruntime.InferenceSession(args.model) - logging.info("converting to blob") - input_blob = cv2.dnn.blobFromImage( - image=input_img, - scalefactor=1 / 255, + def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + + img = Image.open(args.input).convert("RGB") + + logging.info(f"Preprocessing image {args.input}") + tf = A.Compose( + [ + A.ToFloat(max_value=255), + ToTensorV2(), + ], ) + aug = tf(image=np.asarray(img)) + img = aug["image"] - net.setInput(input_blob) - mask = net.forward() - mask = sigmoid(mask) - mask = mask > 0.5 - mask = mask.astype(np.float32) + logging.info(f"Predicting image {args.input}") + img = img.unsqueeze(0) - logging.info(f"Saving prediction to {args.output}") - mask = Image.fromarray(mask, "L") - mask.save(args.output) + # compute ONNX Runtime output prediction + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)} + ort_outs = ort_session.run(None, ort_inputs) + + img_out_y = ort_outs[0] + + img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L") + + img_out_y.save(args.output) From e4562e2481dd0a2572927b375097219115b6d9a1 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 15:18:31 +0200 Subject: [PATCH 04/15] feat: kinda broken Former-commit-id: 4cf02610721ba30c3dd1be6377daeeed907bc651 [formerly 52ef07ec8a123ddd362ac7c930eb6c915848e8b4] Former-commit-id: 29fc18cae50625fd1f2868fc9696ca505f5648e2 --- .gitignore | 1 + poetry.lock | 39 +++++++++++- pyproject.toml | 1 + src/train.py | 114 +++++++++++++---------------------- src/unet/model.py | 149 +++++++++++++++++++++++----------------------- 5 files changed, 153 insertions(+), 151 deletions(-) diff --git a/.gitignore b/.gitignore index cdc3479..40eb468 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__/ wandb/ images/ +lightning_logs/ checkpoints/ *.pth diff --git a/poetry.lock b/poetry.lock index ec4dac1..08ed37b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -189,6 +189,17 @@ category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "commonmark" +version = "0.9.1" +description = "Python parser for the CommonMark Markdown spec" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] + [[package]] name = "cycler" version = "0.11.0" @@ -881,7 +892,7 @@ python-versions = ">=3.6" name = "pygments" version = "2.12.0" description = "Pygments is a syntax highlighting package written in Python." 
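[Editor's note] The Pygments entry flips from a dev to a main dependency here because the newly added `rich` package depends on it, as rich's own lock entry below shows. Patch 04 consumes rich through Lightning's progress-bar callback; the hookup is a single line (sketch):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import RichProgressBar

    trainer = pl.Trainer(callbacks=[RichProgressBar()])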
-category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -1027,6 +1038,22 @@ requests = ">=2.0.0" [package.extras] rsa = ["oauthlib[signedtoken] (>=3.0.0)"] +[[package]] +name = "rich" +version = "12.4.4" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" +optional = false +python-versions = ">=3.6.3,<4.0.0" + +[package.dependencies] +commonmark = ">=0.9.0,<0.10.0" +pygments = ">=2.6.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] + [[package]] name = "rsa" version = "4.8" @@ -1446,7 +1473,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.11" -content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5" +content-hash = "416650c968a0021f7d64028f272464d96319c361a72888ae4cb3e2a602873832" [metadata.files] absl-py = [ @@ -1652,6 +1679,10 @@ colorama = [ {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, ] +commonmark = [ + {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, + {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"}, +] cycler = [ {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"}, {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, @@ -2419,6 +2450,10 @@ requests-oauthlib = [ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, ] +rich = [ + {file = "rich-12.4.4-py3-none-any.whl", hash = "sha256:d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec"}, + {file = "rich-12.4.4.tar.gz", hash = "sha256:4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2"}, +] rsa = [ {file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"}, {file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"}, diff --git a/pyproject.toml b/pyproject.toml index 426fef6..ef17a5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ torch = "^1.11.0" torchvision = "^0.12.0" tqdm = "^4.64.0" wandb = "^0.12.19" +rich = "^12.4.4" [tool.poetry.dev-dependencies] black = "^22.3.0" diff --git a/src/train.py b/src/train.py index e8f3594..2471320 100644 --- a/src/train.py +++ b/src/train.py @@ -3,8 +3,8 @@ import logging import albumentations as A import pytorch_lightning as pl import torch -import yaml from albumentations.pytorch import ToTensorV2 +from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger from torch.utils.data import DataLoader @@ -13,8 +13,27 @@ from src.utils.dataset import SphereDataset from unet import UNet from utils.paste import RandomPaste -class_labels = { - 1: "sphere", +CONFIG = { + "DIR_TRAIN_IMG": 
"/home/lilian/data_disk/lfainsin/train/", + "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", + "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", + "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", + "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", + "FEATURES": [64, 128, 256, 512], + "N_CHANNELS": 3, + "N_CLASSES": 1, + "AMP": True, + "PIN_MEMORY": True, + "BENCHMARK": True, + "DEVICE": "gpu", + "WORKERS": 8, + "EPOCHS": 5, + "BATCH_SIZE": 16, + "LEARNING_RATE": 1e-4, + "WEIGHT_DECAY": 1e-8, + "MOMENTUM": 0.9, + "IMG_SIZE": 512, + "SPHERES": 5, } if __name__ == "__main__": @@ -24,28 +43,7 @@ if __name__ == "__main__": # setup wandb logger = WandbLogger( project="U-Net", - config=dict( - DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/", - DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/", - DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/", - DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", - DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", - FEATURES=[64, 128, 256, 512], - N_CHANNELS=3, - N_CLASSES=1, - AMP=True, - PIN_MEMORY=True, - BENCHMARK=True, - DEVICE="gpu", - WORKERS=8, - EPOCHS=5, - BATCH_SIZE=16, - LEARNING_RATE=1e-4, - WEIGHT_DECAY=1e-8, - MOMENTUM=0.9, - IMG_SIZE=512, - SPHERES=5, - ), + config=CONFIG, settings=wandb.Settings( code_dir="./src/", ), @@ -55,10 +53,7 @@ if __name__ == "__main__": pl.seed_everything(69420, workers=True) # 0. Create network - net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) - - # log the number of parameters of the model - wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) + net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"]) # log gradients and weights regularly logger.watch(net, log="all") @@ -66,88 +61,59 @@ if __name__ == "__main__": # 1. Create transforms tf_train = A.Compose( [ - A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]), A.Flip(), A.ColorJitter(), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]), A.GaussianBlur(), A.ISONoise(), A.ToFloat(max_value=255), ToTensorV2(), ], ) - tf_valid = A.Compose( - [ - A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) # 2. Create datasets - ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) - ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) - ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) + ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train) + ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"]) # 2.5. 
Create subset, if uncommented ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) - ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000))) - ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) + # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) + # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) # 3. Create data loaders train_loader = DataLoader( ds_train, shuffle=True, - batch_size=wandb.config.BATCH_SIZE, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, + batch_size=CONFIG["BATCH_SIZE"], + num_workers=CONFIG["WORKERS"], + pin_memory=CONFIG["PIN_MEMORY"], ) val_loader = DataLoader( ds_valid, shuffle=False, drop_last=True, - batch_size=wandb.config.BATCH_SIZE, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, - ) - test_loader = DataLoader( - ds_test, - shuffle=False, - drop_last=False, batch_size=1, - num_workers=wandb.config.WORKERS, - pin_memory=wandb.config.PIN_MEMORY, + num_workers=CONFIG["WORKERS"], + pin_memory=CONFIG["PIN_MEMORY"], ) # 4. Create the trainer trainer = pl.Trainer( - max_epochs=wandb.config.EPOCHS, - accelerator="gpu", - precision=16, + max_epochs=CONFIG["EPOCHS"], + accelerator=CONFIG["DEVICE"], + # precision=16, auto_scale_batch_size="binsearch", - benchmark=wandb.config.BENCHMARK, + benchmark=CONFIG["BENCHMARK"], val_check_interval=100, + callbacks=RichProgressBar(), ) - # print the config - logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") - - # # wandb init log - # wandb.log( - # { - # "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], - # }, - # commit=False, - # ) - try: trainer.fit( model=net, train_dataloaders=train_loader, val_dataloaders=val_loader, - test_dataloaders=test_loader, - accelerator=wandb.config.DEVICE, ) except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") diff --git a/src/unet/model.py b/src/unet/model.py index 378b407..b9d6c18 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -1,7 +1,5 @@ """ Full assembly of the parts to form the complete network """ -from xmlrpc.server import list_public_methods - import numpy as np import pytorch_lightning as pl @@ -40,6 +38,7 @@ class UNet(pl.LightningModule): def forward(self, x): skips = [] + x = x.to(self.device) x = self.inc(x) for down in self.downs: @@ -53,8 +52,7 @@ class UNet(pl.LightningModule): return x - @staticmethod - def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key): + def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key): table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) for i, (img, mask, pred, pred_bin) in enumerate( @@ -99,16 +97,17 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() dice = dice_coeff(masks_pred_bin, masks_true) - wandb.log( + self.log( + "train", { - "train/accuracy": accuracy, - "train/bce": loss, - "train/dice": dice, - "train/mae": mae, - } + "accuracy": accuracy, + "bce": loss, + "dice": dice, + "mae": mae, + }, ) - return loss, dice, accuracy, mae + return loss # , dice, accuracy, mae def validation_step(self, batch, batch_idx): # unpacking @@ -119,79 +118,79 @@ class UNet(pl.LightningModule): # compute metrics loss = F.cross_entropy(masks_pred, masks_true) - mae = torch.nn.functional.l1_loss(masks_pred_bin, 
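[Editor's note] Part of why this commit is labeled "kinda broken": `self.log` expects one scalar per key, so the `self.log("train", {...})` call above hands an entire dict to a single-metric API. The next patch settles on the dict-aware variant:

    self.log_dict(
        {
            "train/accuracy": accuracy,
            "train/bce": bce,
            "train/dice": dice,
            "train/mae": mae,
        }
    )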
masks_true) - accuracy = (masks_true == masks_pred_bin).float().mean() - dice = dice_coeff(masks_pred_bin, masks_true) + # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + # accuracy = (masks_true == masks_pred_bin).float().mean() + # dice = dice_coeff(masks_pred_bin, masks_true) if batch_idx == 0: self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions") - return loss, dice, accuracy, mae + return loss # , dice, accuracy, mae - def validation_step_end(self, validation_outputs): - # unpacking - loss, dice, accuracy, mae = validation_outputs - optimizer = self.optimizers[0] - learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + # def validation_step_end(self, validation_outputs): + # # unpacking + # loss, dice, accuracy, mae = validation_outputs + # # optimizer = self.optimizers[0] + # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] - wandb.log( - { - "train/learning_rate": learning_rate, - "val/accuracy": accuracy, - "val/bce": loss, - "val/dice": dice, - "val/mae": mae, - } - ) + # wandb.log( + # { + # # "train/learning_rate": learning_rate, + # "val/accuracy": accuracy, + # "val/bce": loss, + # "val/dice": dice, + # "val/mae": mae, + # } + # ) - # export model to onnx - dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) - torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") - artifact = wandb.Artifact("onnx", type="model") - artifact.add_file(f"checkpoints/model.onnx") - wandb.run.log_artifact(artifact) + # # export model to onnx + # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) + # torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") + # artifact = wandb.Artifact("onnx", type="model") + # artifact.add_file(f"checkpoints/model.onnx") + # wandb.run.log_artifact(artifact) - def test_step(self, batch, batch_idx): - # unpacking - images, masks_true = batch - masks_true = masks_true.unsqueeze(1) - masks_pred = self(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + # def test_step(self, batch, batch_idx): + # # unpacking + # images, masks_true = batch + # masks_true = masks_true.unsqueeze(1) + # masks_pred = self(images) + # masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - # compute metrics - loss = F.cross_entropy(masks_pred, masks_true) - mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - accuracy = (masks_true == masks_pred_bin).float().mean() - dice = dice_coeff(masks_pred_bin, masks_true) + # # compute metrics + # loss = F.cross_entropy(masks_pred, masks_true) + # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + # accuracy = (masks_true == masks_pred_bin).float().mean() + # dice = dice_coeff(masks_pred_bin, masks_true) - if batch_idx == 0: - self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") + # if batch_idx == 0: + # self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") - return loss, dice, accuracy, mae + # return loss, dice, accuracy, mae - def test_step_end(self, test_outputs): - # unpacking - list_loss, list_dice, list_accuracy, list_mae = test_outputs + # def test_step_end(self, test_outputs): + # # unpacking + # list_loss, list_dice, list_accuracy, list_mae = test_outputs - # averaging - loss = np.mean(list_loss) - dice = np.mean(list_dice) - accuracy = np.mean(list_accuracy) - mae = np.mean(list_mae) + # # averaging + # loss = np.mean(list_loss) + # dice = np.mean(list_dice) + # accuracy = np.mean(list_accuracy) + # mae = 
np.mean(list_mae) - # get learning rate - optimizer = self.optimizers[0] - learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + # # # get learning rate + # # optimizer = self.optimizers[0] + # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] - wandb.log( - { - "train/learning_rate": learning_rate, - "val/accuracy": accuracy, - "val/bce": loss, - "val/dice": dice, - "val/mae": mae, - } - ) + # wandb.log( + # { + # # "train/learning_rate": learning_rate, + # "test/accuracy": accuracy, + # "test/bce": loss, + # "test/dice": dice, + # "test/mae": mae, + # } + # ) def configure_optimizers(self): optimizer = torch.optim.RMSprop( @@ -200,10 +199,10 @@ class UNet(pl.LightningModule): weight_decay=wandb.config.WEIGHT_DECAY, momentum=wandb.config.MOMENTUM, ) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - "max", - patience=2, - ) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + # optimizer, + # "max", + # patience=2, + # ) - return optimizer, scheduler + return optimizer # , scheduler From 40ea1c3191a70e18e5384202aa523dca03822f0f Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 22:31:38 +0200 Subject: [PATCH 05/15] feat: working logging, auto_batch/lr still not working Former-commit-id: 29d4536eb182f84eb2cc9a4e31f31bf19a4ca272 [formerly f5fd5eec9394b81f15986fb6cbabf675b2f05c04] Former-commit-id: 3de00ee718a761221c1934b7cbaaa0ad5487856d --- src/train.py | 15 ++++-- src/unet/model.py | 124 +++++++++++++++++++++++++++------------------- 2 files changed, 84 insertions(+), 55 deletions(-) diff --git a/src/train.py b/src/train.py index 2471320..c54988a 100644 --- a/src/train.py +++ b/src/train.py @@ -19,7 +19,7 @@ CONFIG = { "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", - "FEATURES": [64, 128, 256, 512], + "FEATURES": [16, 32, 64, 128], "N_CHANNELS": 3, "N_CLASSES": 1, "AMP": True, @@ -53,7 +53,13 @@ if __name__ == "__main__": pl.seed_everything(69420, workers=True) # 0. Create network - net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"]) + net = UNet( + n_channels=CONFIG["N_CHANNELS"], + n_classes=CONFIG["N_CLASSES"], + batch_size=CONFIG["BATCH_SIZE"], + learning_rate=CONFIG["LEARNING_RATE"], + features=CONFIG["FEATURES"], + ) # log gradients and weights regularly logger.watch(net, log="all") @@ -77,7 +83,7 @@ if __name__ == "__main__": ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"]) # 2.5. 
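[Editor's note] Passing `batch_size` and `learning_rate` into the module, as the hunk above starts doing, is what Lightning's tuner requires: `auto_scale_batch_size` and `auto_lr_find`, enabled on the trainer just below, look for module attributes with exactly those names and overwrite them. The contract, sketched for PL 1.x (a real module also needs its steps and dataloaders):

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self, learning_rate=1e-4, batch_size=16):
            super().__init__()
            self.learning_rate = learning_rate  # rewritten by auto_lr_find
            self.batch_size = batch_size        # rewritten by auto_scale_batch_size

    model = LitModel()
    trainer = pl.Trainer(auto_lr_find=True, auto_scale_batch_size="binsearch")
    trainer.tune(model)  # runs both finders before trainer.fit(model)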
Create subset, if uncommented - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) + ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) @@ -104,9 +110,12 @@ if __name__ == "__main__": accelerator=CONFIG["DEVICE"], # precision=16, auto_scale_batch_size="binsearch", + auto_lr_find=True, benchmark=CONFIG["BENCHMARK"], val_check_interval=100, callbacks=RichProgressBar(), + logger=logger, + log_every_n_steps=1, ) try: diff --git a/src/unet/model.py b/src/unet/model.py index b9d6c18..be5712e 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -1,6 +1,5 @@ """ Full assembly of the parts to form the complete network """ -import numpy as np import pytorch_lightning as pl import wandb @@ -14,11 +13,16 @@ class_labels = { class UNet(pl.LightningModule): - def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): + def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]): super(UNet, self).__init__() + + # Hyperparameters self.n_channels = n_channels self.n_classes = n_classes + self.learning_rate = learning_rate + self.batch_size = batch_size + # Network self.inc = DoubleConv(n_channels, features[0]) self.downs = nn.ModuleList() @@ -39,6 +43,7 @@ class UNet(pl.LightningModule): skips = [] x = x.to(self.device) + x = self.inc(x) for down in self.downs: @@ -78,77 +83,97 @@ class UNet(pl.LightningModule): ), ) - wandb.log( - { - log_key: table, - } - ) + wandb.log({log_key: table}) # replace by self.log def training_step(self, batch, batch_idx): # unpacking images, masks_true = batch masks_true = masks_true.unsqueeze(1) - masks_pred = self(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - # compute metrics - loss = F.cross_entropy(masks_pred, masks_true) + # forward pass + masks_pred = self(images) + + # compute loss + bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) + + # compute other metrics + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy = (masks_true == masks_pred_bin).float().mean() dice = dice_coeff(masks_pred_bin, masks_true) - self.log( - "train", + self.log_dict( { - "accuracy": accuracy, - "bce": loss, - "dice": dice, - "mae": mae, + "train/accuracy": accuracy, + "train/bce": bce, + "train/dice": dice, + "train/mae": mae, }, ) - return loss # , dice, accuracy, mae + return dict( + loss=bce, + dice=dice, + accuracy=accuracy, + mae=mae, + ) def validation_step(self, batch, batch_idx): # unpacking images, masks_true = batch masks_true = masks_true.unsqueeze(1) - masks_pred = self(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - # compute metrics - loss = F.cross_entropy(masks_pred, masks_true) - # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - # accuracy = (masks_true == masks_pred_bin).float().mean() - # dice = dice_coeff(masks_pred_bin, masks_true) + # forward pass + masks_pred = self(images) + + # compute loss + bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) + + # compute other metrics + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + accuracy = (masks_true == masks_pred_bin).float().mean() + 
dice = dice_coeff(masks_pred_bin, masks_true) if batch_idx == 0: self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions") - return loss # , dice, accuracy, mae + return dict( + loss=bce, + dice=dice, + accuracy=accuracy, + mae=mae, + ) - # def validation_step_end(self, validation_outputs): - # # unpacking - # loss, dice, accuracy, mae = validation_outputs - # # optimizer = self.optimizers[0] - # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + def validation_epoch_end(self, validation_outputs): + # unpacking + accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean() + loss = torch.stack([d["loss"] for d in validation_outputs]).mean() + dice = torch.stack([d["dice"] for d in validation_outputs]).mean() + mae = torch.stack([d["mae"] for d in validation_outputs]).mean() - # wandb.log( - # { - # # "train/learning_rate": learning_rate, - # "val/accuracy": accuracy, - # "val/bce": loss, - # "val/dice": dice, - # "val/mae": mae, - # } - # ) + # logging + wandb.log( + { + "val/accuracy": accuracy, + "val/bce": loss, + "val/dice": dice, + "val/mae": mae, + } + ) - # # export model to onnx - # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) - # torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") - # artifact = wandb.Artifact("onnx", type="model") - # artifact.add_file(f"checkpoints/model.onnx") - # wandb.run.log_artifact(artifact) + # export model to pth + torch.save(self.state_dict(), f"checkpoints/model.pth") + artifact = wandb.Artifact("pth", type="model") + artifact.add_file(f"checkpoints/model.pth") + wandb.run.log_artifact(artifact) + + # export model to onnx + dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) + torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") + artifact = wandb.Artifact("onnx", type="model") + artifact.add_file(f"checkpoints/model.onnx") + wandb.run.log_artifact(artifact) # def test_step(self, batch, batch_idx): # # unpacking @@ -199,10 +224,5 @@ class UNet(pl.LightningModule): weight_decay=wandb.config.WEIGHT_DECAY, momentum=wandb.config.MOMENTUM, ) - # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - # optimizer, - # "max", - # patience=2, - # ) - return optimizer # , scheduler + return optimizer From 13064521522e60be05846475db41602b2a44d74e Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 5 Jul 2022 22:52:34 +0200 Subject: [PATCH 06/15] fix: bad logging Former-commit-id: 221ec6b6bfcf4e2e616a1db688ef8e93a2bb5bfc [formerly 0e7792974319cbc693fdae3597d44f9c4c196b4d] Former-commit-id: 70f179df977715a747b981f7e2784e3c7ed88028 --- src/unet/model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/unet/model.py b/src/unet/model.py index be5712e..008d784 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -83,7 +83,10 @@ class UNet(pl.LightningModule): ), ) - wandb.log({log_key: table}) # replace by self.log + wandb.log( + {log_key: table}, + commit=False, + ) # replace by self.log def training_step(self, batch, batch_idx): # unpacking @@ -153,7 +156,7 @@ class UNet(pl.LightningModule): mae = torch.stack([d["mae"] for d in validation_outputs]).mean() # logging - wandb.log( + self.log_dict( { "val/accuracy": accuracy, "val/bce": loss, From f9ca8532a0af6a466bce266ba647b44857f8ef88 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 6 Jul 2022 11:01:07 +0200 Subject: [PATCH 07/15] feat: multiple validation image logging Former-commit-id: 84f7cfedb2688b15ab4401bcbfabcf6bacfa912b [formerly 
51c2df31e934f80b1fd387fa123ccc8f9afac365]
Former-commit-id: 4a06d7d2891628fe798cca24f6bb10081d40e251
---
 src/unet/model.py | 74 +++++++++++++++++++++++++----------------------
 1 file changed, 40 insertions(+), 34 deletions(-)

diff --git a/src/unet/model.py b/src/unet/model.py
index 008d784..5f613f4 100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -1,5 +1,7 @@
 """ Full assembly of the parts to form the complete network """
 
+import itertools
+
 import pytorch_lightning as pl
 
 import wandb
@@ -57,37 +59,6 @@ class UNet(pl.LightningModule):
 
         return x
 
-    def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key):
-        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-
-        for i, (img, mask, pred, pred_bin) in enumerate(
-            zip(
-                images.cpu(),
-                masks_true.cpu(),
-                masks_pred.cpu(),
-                masks_pred_bin.cpu().squeeze(1).int().numpy(),
-            )
-        ):
-            table.add_data(
-                i,
-                wandb.Image(img),
-                wandb.Image(mask),
-                wandb.Image(
-                    pred,
-                    masks={
-                        "predictions": {
-                            "mask_data": pred_bin,
-                            "class_labels": class_labels,
-                        },
-                    },
-                ),
-            )
-
-        wandb.log(
-            {log_key: table},
-            commit=False,
-        )  # replace by self.log
-
     def training_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
@@ -138,24 +109,59 @@ class UNet(pl.LightningModule):
         accuracy = (masks_true == masks_pred_bin).float().mean()
         dice = dice_coeff(masks_pred_bin, masks_true)
 
-        if batch_idx == 0:
-            self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
+        if batch_idx < 6:
+            rows = []
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    masks_true.cpu(),
+                    masks_pred.cpu(),
+                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                    ]
+                )
 
         return dict(
             loss=bce,
             dice=dice,
             accuracy=accuracy,
             mae=mae,
+            table_rows=rows,
         )
 
     def validation_epoch_end(self, validation_outputs):
-        # unpacking
+        # metrics unpacking
         accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
         loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
         dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
         mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
 
+        # table unpacking
+        columns = ["ID", "image", "ground truth", "prediction"]
+        rowss = [d["table_rows"] for d in validation_outputs]
+        rows = list(itertools.chain.from_iterable(rowss))
+
         # logging
+        self.logger.log_table(
+            key="val/predictions",
+            columns=columns,
+            data=rows,
+        )
         self.log_dict(
             {
                 "val/accuracy": accuracy,

From 5a74af6cdbfc7eda980dbadb32626d7245466a93 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Wed, 6 Jul 2022 11:57:21 +0200
Subject: [PATCH 08/15] feat: automatic batch/lr guessing sorta works

Former-commit-id: 346f1f55bab70df44bf15ab04c9a97f256e3d19c [formerly e027de4b57339dccc540ec11cfe81d5278c20d57]
Former-commit-id: 9f3537abccca7ab3d433df318cc7acf6bfc610c4
---
 .gitignore | 1 +
 src/train.py | 59 +++++----------------------------
 src/unet/model.py | 61 +++++++++++++++++++++++++++++++++++++++++------
 3 files changed, 61 insertions(+), 60 deletions(-)

diff --git a/.gitignore b/.gitignore
index 40eb468..d8b8e64 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@ lightning_logs/
 checkpoints/
 *.pth
 *.onnx
+*.ckpt
 
 *.png
 *.jpg
diff --git a/src/train.py b/src/train.py
index c54988a..656eea4 100644
--- a/src/train.py
+++ b/src/train.py
@@ 
-1,17 +1,13 @@ import logging -import albumentations as A import pytorch_lightning as pl import torch -from albumentations.pytorch import ToTensorV2 from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger from torch.utils.data import DataLoader import wandb -from src.utils.dataset import SphereDataset from unet import UNet -from utils.paste import RandomPaste CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", @@ -52,7 +48,7 @@ if __name__ == "__main__": # seed random generators pl.seed_everything(69420, workers=True) - # 0. Create network + # Create network net = UNet( n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], @@ -64,53 +60,13 @@ if __name__ == "__main__": # log gradients and weights regularly logger.watch(net, log="all") - # 1. Create transforms - tf_train = A.Compose( - [ - A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]), - A.Flip(), - A.ColorJitter(), - RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]), - A.GaussianBlur(), - A.ISONoise(), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) - - # 2. Create datasets - ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train) - ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"]) - - # 2.5. Create subset, if uncommented - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) - # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) - # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) - - # 3. Create data loaders - train_loader = DataLoader( - ds_train, - shuffle=True, - batch_size=CONFIG["BATCH_SIZE"], - num_workers=CONFIG["WORKERS"], - pin_memory=CONFIG["PIN_MEMORY"], - ) - val_loader = DataLoader( - ds_valid, - shuffle=False, - drop_last=True, - batch_size=1, - num_workers=CONFIG["WORKERS"], - pin_memory=CONFIG["PIN_MEMORY"], - ) - - # 4. 
Create the trainer + # Create the trainer trainer = pl.Trainer( max_epochs=CONFIG["EPOCHS"], accelerator=CONFIG["DEVICE"], # precision=16, - auto_scale_batch_size="binsearch", - auto_lr_find=True, + # auto_scale_batch_size="binsearch", + # auto_lr_find=True, benchmark=CONFIG["BENCHMARK"], val_check_interval=100, callbacks=RichProgressBar(), @@ -119,11 +75,8 @@ if __name__ == "__main__": ) try: - trainer.fit( - model=net, - train_dataloaders=train_loader, - val_dataloaders=val_loader, - ) + trainer.tune(net) + trainer.fit(model=net) except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") raise diff --git a/src/unet/model.py b/src/unet/model.py index 5f613f4..ccc036c 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -2,10 +2,15 @@ import itertools +import albumentations as A import pytorch_lightning as pl +from albumentations.pytorch import ToTensorV2 +from torch.utils.data import DataLoader import wandb +from src.utils.dataset import SphereDataset from utils.dice import dice_coeff +from utils.paste import RandomPaste from .blocks import * @@ -24,6 +29,9 @@ class UNet(pl.LightningModule): self.learning_rate = learning_rate self.batch_size = batch_size + # log hyperparameters + self.save_hyperparameters() + # Network self.inc = DoubleConv(n_channels, features[0]) @@ -59,6 +67,42 @@ class UNet(pl.LightningModule): return x + def train_dataloader(self): + tf_train = A.Compose( + [ + A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + A.Flip(), + A.ColorJitter(), + RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + A.GaussianBlur(), + A.ISONoise(), + A.ToFloat(max_value=255), + ToTensorV2(), + ], + ) + + ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) + ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) + + return DataLoader( + ds_train, + batch_size=self.batch_size, + shuffle=True, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) + + def val_dataloader(self): + ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) + + return DataLoader( + ds_valid, + shuffle=False, + batch_size=1, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) + def training_step(self, batch, batch_idx): # unpacking images, masks_true = batch @@ -109,8 +153,8 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() dice = dice_coeff(masks_pred_bin, masks_true) + rows = [] if batch_idx < 6: - rows = [] for i, (img, mask, pred, pred_bin) in enumerate( zip( images.cpu(), @@ -157,11 +201,14 @@ class UNet(pl.LightningModule): rows = list(itertools.chain.from_iterable(rowss)) # logging - self.logger.log_table( - key="val/predictions", - columns=columns, - data=rows, - ) + try: + self.logger.log_table( + key="val/predictions", + columns=columns, + data=rows, + ) + except: + pass self.log_dict( { "val/accuracy": accuracy, @@ -229,7 +276,7 @@ class UNet(pl.LightningModule): def configure_optimizers(self): optimizer = torch.optim.RMSprop( self.parameters(), - lr=wandb.config.LEARNING_RATE, + lr=self.learning_rate, weight_decay=wandb.config.WEIGHT_DECAY, momentum=wandb.config.MOMENTUM, ) From b71b57285f106c58f5b0e384bca7908a2b9bc7d0 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 6 Jul 2022 14:27:26 +0200 Subject: [PATCH 09/15] feat: using dice_loss feat: paste aug contrast/sharpness Former-commit-id: 93f19e9643858a81ace14e9a697dfb6b3cca4d47 [formerly 
f6ef5f65e84f37b4b55a99a49442b7d30d6d3911]
Former-commit-id: 2f49a81340a91ab7456d093a849ed294457f8a83
---
 src/train.py | 2 +-
 src/unet/model.py | 81 ++++++++++++--------------------------------
 src/utils/dice.py | 84 ++++++----------------------------------------
 src/utils/paste.py | 16 ++++++---
 4 files changed, 46 insertions(+), 137 deletions(-)

diff --git a/src/train.py b/src/train.py
index 656eea4..d97802d 100644
--- a/src/train.py
+++ b/src/train.py
@@ -23,7 +23,7 @@ CONFIG = {
     "BENCHMARK": True,
     "DEVICE": "gpu",
     "WORKERS": 8,
-    "EPOCHS": 5,
+    "EPOCHS": 10,
     "BATCH_SIZE": 16,
     "LEARNING_RATE": 1e-4,
     "WEIGHT_DECAY": 1e-8,
diff --git a/src/unet/model.py b/src/unet/model.py
index ccc036c..4e87827 100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 
 import wandb
 from src.utils.dataset import SphereDataset
-from utils.dice import dice_coeff
+from utils.dice import dice_loss
 from utils.paste import RandomPaste
 
 from .blocks import *
@@ -111,28 +111,29 @@ class UNet(pl.LightningModule):
         # forward pass
         masks_pred = self(images)
 
-        # compute loss
+        # compute metrics
         bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
 
-        # compute other metrics
         masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
         mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
         accuracy = (masks_true == masks_pred_bin).float().mean()
-        dice = dice_coeff(masks_pred_bin, masks_true)
 
         self.log_dict(
             {
                 "train/accuracy": accuracy,
-                "train/bce": bce,
                 "train/dice": dice,
+                "train/dice_bin": dice_bin,
+                "train/bce": bce,
                 "train/mae": mae,
             },
         )
 
         return dict(
-            loss=bce,
-            dice=dice,
             accuracy=accuracy,
+            loss=dice,
+            bce=bce,
             mae=mae,
         )
 
@@ -144,17 +145,17 @@
-        # compute loss
+        # compute metrics
         bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
 
-        # compute other metrics
         masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
         mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
         accuracy = (masks_true == masks_pred_bin).float().mean()
-        dice = dice_coeff(masks_pred_bin, masks_true)
 
         rows = []
-        if batch_idx < 6:
+        if batch_idx % 50 == 0:
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip(
                     images.cpu(),
@@ -181,9 +182,10 @@
                 )
 
         return dict(
-            loss=bce,
-            dice=dice,
             accuracy=accuracy,
+            loss=dice,
+            dice_bin=dice_bin,
+            bce=bce,
             mae=mae,
             table_rows=rows,
         )
 
     def validation_epoch_end(self, validation_outputs):
-        # unpacking
+        # metrics unpacking
         accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
+        dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
         loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
-        dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
+        bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
         mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
 
         # table unpacking
@@ -201,7 +204,7 @@
         rows = list(itertools.chain.from_iterable(rowss))
 
         # logging
-        try:
+        try:  # required by autofinding, logger replaced by dummy
             self.logger.log_table(
                 key="val/predictions",
                 columns=columns,
                 data=rows,
             )
         except:
             pass
+
self.log_dict( { "val/accuracy": accuracy, - "val/bce": loss, - "val/dice": dice, + "val/dice": loss, + "val/dice_bin": dice_bin, + "val/bce": bce, "val/mae": mae, } ) @@ -231,48 +236,6 @@ class UNet(pl.LightningModule): artifact.add_file(f"checkpoints/model.onnx") wandb.run.log_artifact(artifact) - # def test_step(self, batch, batch_idx): - # # unpacking - # images, masks_true = batch - # masks_true = masks_true.unsqueeze(1) - # masks_pred = self(images) - # masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - - # # compute metrics - # loss = F.cross_entropy(masks_pred, masks_true) - # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - # accuracy = (masks_true == masks_pred_bin).float().mean() - # dice = dice_coeff(masks_pred_bin, masks_true) - - # if batch_idx == 0: - # self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") - - # return loss, dice, accuracy, mae - - # def test_step_end(self, test_outputs): - # # unpacking - # list_loss, list_dice, list_accuracy, list_mae = test_outputs - - # # averaging - # loss = np.mean(list_loss) - # dice = np.mean(list_dice) - # accuracy = np.mean(list_accuracy) - # mae = np.mean(list_mae) - - # # # get learning rate - # # optimizer = self.optimizers[0] - # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] - - # wandb.log( - # { - # # "train/learning_rate": learning_rate, - # "test/accuracy": accuracy, - # "test/bce": loss, - # "test/dice": dice, - # "test/mae": mae, - # } - # ) - def configure_optimizers(self): optimizer = torch.optim.RMSprop( self.parameters(), diff --git a/src/utils/dice.py b/src/utils/dice.py index a29f794..acdc4b2 100644 --- a/src/utils/dice.py +++ b/src/utils/dice.py @@ -1,80 +1,18 @@ import torch -from torch import Tensor -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all batches, or for a single mask. +def dice_score(inputs, targets, smooth=1, logits=True): + # comment out if your model contains a sigmoid or equivalent activation layer + if logits: + inputs = torch.sigmoid(inputs) - Args: - input (Tensor): _description_ - target (Tensor): _description_ - reduce_batch_first (bool, optional): _description_. Defaults to False. - epsilon (_type_, optional): _description_. Defaults to 1e-6. + # flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) - Raises: - ValueError: _description_ - - Returns: - float: _description_ - """ - assert input.size() == target.size() - - if input.dim() == 2 and reduce_batch_first: - raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})") - - if input.dim() == 2 or reduce_batch_first: - inter = torch.dot(input.reshape(-1), target.reshape(-1)) - sets_sum = torch.sum(input) + torch.sum(target) - - if sets_sum.item() == 0: - sets_sum = 2 * inter - - return (2 * inter + epsilon) / (sets_sum + epsilon) - else: - # compute and average metric for each batch element - dice = 0 - - for i in range(input.shape[0]): - dice += dice_coeff(input[i, ...], target[i, ...]) - - return dice / input.shape[0] + intersection = (inputs * targets).sum() + return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float: - """Average of Dice coefficient for all classes. 
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    dice = 0
-
-    for channel in range(input.shape[1]):
-        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
-
-    return dice / input.shape[1]
-
-
-def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
-    """Dice loss (objective to minimize) between 0 and 1.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        multiclass (bool, optional): _description_. Defaults to False.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    fn = multiclass_dice_coeff if multiclass else dice_coeff
-
-    return 1 - fn(input, target, reduce_batch_first=True)
+def dice_loss(inputs, targets, smooth=1, logits=True):
+    return 1 - dice_score(inputs, targets, smooth, logits)

diff --git a/src/utils/paste.py b/src/utils/paste.py
index 486a8ec..a1e24e4 100644
--- a/src/utils/paste.py
+++ b/src/utils/paste.py
@@ -93,6 +93,18 @@ class RandomPaste(A.DualTransform):
         target_shape = np.array(target_img.shape[:2], dtype=np.uint)
         paste_shape = np.array(paste_img.size, dtype=np.uint)
 
+        # change paste_img's brightness randomly
+        filter = ImageEnhance.Brightness(paste_img)
+        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
+
+        # change paste_img's contrast randomly
+        filter = ImageEnhance.Contrast(paste_img)
+        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
+
+        # change paste_img's sharpness randomly
+        filter = ImageEnhance.Sharpness(paste_img)
+        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
+
         # compute the minimum scaling to fit inside target image
         min_scale = np.min(target_shape / paste_shape)
 
@@ -117,10 +129,6 @@ class RandomPaste(A.DualTransform):
         # update paste_shape after scaling
         paste_shape = np.array(paste_img.size, dtype=np.uint)
 
-        # change brightness randomly
-        filter = ImageEnhance.Brightness(paste_img)
-        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
-
         # generate some positions
         positions = []
         NB = rd.randint(1, self.nb)
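For reference, the dice_score/dice_loss pair introduced above is the standard smoothed Dice formulation on flattened tensors: score = (2 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth), and dice_loss = 1 - dice_score; with logits=True (the default) the inputs first go through a sigmoid, so the loss stays differentiable on raw network outputs. A quick sanity check of the new functions, as a standalone sketch assuming src/utils/dice.py is importable as utils.dice (the import style used in model.py):

import torch

from utils.dice import dice_loss, dice_score

targets = torch.tensor([[0.0, 1.0, 1.0, 0.0]])

# a perfect binarized prediction: intersection = 2 and inputs.sum() + targets.sum() = 4,
# so the smoothed score is (2 * 2 + 1) / (4 + 1) = 1.0 and the loss is exactly 0.0
print(dice_score(targets.clone(), targets, logits=False))  # tensor(1.)
print(dice_loss(targets.clone(), targets, logits=False))  # tensor(0.)

# with logits=True, confident raw outputs of the correct sign also drive the loss towards 0
logits = torch.tensor([[-10.0, 10.0, 10.0, -10.0]])
print(dice_loss(logits, targets))  # ~4e-5

From 0dd606144fbe1e5cc6a6bdf800cf74d8e65408c7 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Thu, 7 Jul 2022 12:06:41 +0200
Subject: [PATCH 10/15] feat: new paste dataset

Former-commit-id: 039874208d5a27bf01beb2746a77502fd836ae5c [formerly 66638fcabaea1044d9a2fd48e6ffb20f149ebf47]
Former-commit-id: 6bdf8bba0b3cbd8706337aa3167c36fba8855a4c
---
 comp.ipynb.REMOVED.git-id | 2 +-
 extract.ipynb | 177 ++++++++++++++++++++++++++++++++++++++
 src/train.py | 14 ++-
 src/unet/model.py | 6 +-
 src/utils/paste.py | 28 +++---
 5 files changed, 206 insertions(+), 21 deletions(-)
 create mode 100644 extract.ipynb

diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id
index b439b71..3c6779c 100644
--- a/comp.ipynb.REMOVED.git-id
+++ b/comp.ipynb.REMOVED.git-id
@@ -1 +1 @@
-9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d
\ No newline at end of file
+3c9a34f197340a6051eb34d11695c7d6b72164f0
\ No newline at end of file

diff --git a/extract.ipynb b/extract.ipynb
new file mode 100644
index 0000000..f5c6750
--- /dev/null
+++ b/extract.ipynb
@@ -0,0 +1,177 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from PIL import Image\n",
+    "import numpy as np\n",
+    "\n",
+    "import albumentations as A\n",
+    "\n",
+    "%config InlineBackend.figure_formats = ['svg']\n",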
+    "import matplotlib.pyplot as plt\n",
+    "%matplotlib inline\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": "[stripped SVG output: matplotlib figure generated 2022-07-07T10:16:03.003643 with Matplotlib v3.5.2, https://matplotlib.org/]",
+      "text/plain": [
+       ""
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "img = Image.open(\"/tmp/extract/photo.jpg\").convert(\"RGBA\")\n", + "mask = Image.open(\"/tmp/extract/MASK.PNG\").convert(\"LA\")\n", + "\n", + "plt.figure(figsize=(18, 10))\n", + "\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(img)\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(mask)\n", + "\n", + "ax = plt.gca()\n", + "ax.set_facecolor('xkcd:salmon')\n", + "\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n \n \n \n \n 2022-07-07T10:22:39.796028\n image/svg+xml\n \n \n Matplotlib v3.5.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(434, 425)\n" + ] + } + ], + "source": [ + "box = mask.getbbox()\n", + "\n", + "crop_img = img.crop(box)\n", + "crop_mask = mask.crop(box)\n", + "\n", + "plt.figure(figsize=(18, 10))\n", + "\n", + "plt.subplot(2, 2, 1)\n", + "plt.imshow(crop_img)\n", + "\n", + "plt.subplot(2, 2, 2)\n", + "plt.imshow(crop_mask)\n", + "\n", + "ax = plt.gca()\n", + "ax.set_facecolor('xkcd:salmon')\n", + "\n", + "empty = Image.fromarray(np.zeros(crop_img.size), \"RGBA\")\n", + "empty.paste(crop_img, crop_mask)\n", + "\n", + "plt.subplot(2, 2, 3)\n", + "plt.imshow(empty.resize((100, 100)))\n", + "\n", + "plt.subplot(2, 2, 4)\n", + "plt.imshow(crop_mask.resize((100, 100)))\n", + "\n", + "ax = plt.gca()\n", + "ax.set_facecolor('xkcd:salmon')\n", + "\n", + "plt.show()\n", + "\n", + "print(crop_img.size)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "spheres_to_extract_dir = \"/home/lilian/data_disk/lfainsin/test/\"\n", + "\n", + "spheres = list(Path(spheres_to_extract_dir).glob(\"**/*.jpg\"))\n", + "\n", + "parents = [path.parent for path in spheres]\n", + "parents = set(parents)\n", + "\n", + "for parent in parents:\n", + " mask_path = parent.joinpath(\"MASK.PNG\")\n", + " mask = Image.open(mask_path).convert(\"LA\")\n", + " box = mask.getbbox()\n", + " crop_mask = mask.crop(box)\n", + "\n", + " filename = Path(\"/tmp/saves/\" + str(mask_path).strip(spheres_to_extract_dir))\n", + " filename.parent.mkdir(parents=True, exist_ok=True)\n", + " crop_mask.save(filename)\n", + "\n", + " spheres = list(parent.glob(\"*.jpg\"))\n", + " for sphere in spheres:\n", + " img = Image.open(sphere).convert(\"RGB\")\n", + " crop_img = img.crop(box)\n", + "\n", + " filename = Path(\"/tmp/saves/\" + str(sphere).strip(spheres_to_extract_dir))\n", + " filename.parent.mkdir(parents=True, exist_ok=True)\n", + " crop_img.save(filename)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.0 ('.venv': poetry)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/train.py b/src/train.py index d97802d..eb1b115 100644 --- a/src/train.py +++ b/src/train.py @@ -15,16 +15,22 @@ CONFIG = { "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", - "FEATURES": [16, 32, 64, 128], + # "FEATURES": [1, 2, 4, 8], + # "FEATURES": [4, 8, 16, 32], + # "FEATURES": [8, 16, 32, 64], + # "FEATURES": [4, 8, 16, 32, 64], + "FEATURES": [8, 16, 32, 64, 128], + # "FEATURES": [16, 32, 64, 128], + # "FEATURES": [64, 128, 256, 512], "N_CHANNELS": 3, "N_CLASSES": 1, "AMP": True, "PIN_MEMORY": True, "BENCHMARK": True, "DEVICE": "gpu", - "WORKERS": 8, - "EPOCHS": 10, - "BATCH_SIZE": 16, + "WORKERS": 10, + "EPOCHS": 1, + "BATCH_SIZE": 32, "LEARNING_RATE": 1e-4, "WEIGHT_DECAY": 1e-8, "MOMENTUM": 
0.9, diff --git a/src/unet/model.py b/src/unet/model.py index 4e87827..4508b27 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -82,7 +82,7 @@ class UNet(pl.LightningModule): ) ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) + # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) return DataLoader( ds_train, @@ -178,6 +178,8 @@ class UNet(pl.LightningModule): }, }, ), + dice, + dice_bin, ] ) @@ -199,7 +201,7 @@ class UNet(pl.LightningModule): mae = torch.stack([d["mae"] for d in validation_outputs]).mean() # table unpacking - columns = ["ID", "image", "ground truth", "prediction"] + columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] rowss = [d["table_rows"] for d in validation_outputs] rows = list(itertools.chain.from_iterable(rowss)) diff --git a/src/utils/paste.py b/src/utils/paste.py index a1e24e4..a25289b 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -1,5 +1,6 @@ import os import random as rd +from pathlib import Path import albumentations as A import numpy as np @@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform): def __init__( self, nb, - path_paste_img_dir, - path_paste_mask_dir, + image_dir, scale_range=(0.1, 0.2), always_apply=True, p=1.0, ): super().__init__(always_apply, p) - self.path_paste_img_dir = path_paste_img_dir - self.path_paste_mask_dir = path_paste_mask_dir + self.images = [] + self.images.extend(list(Path(image_dir).glob("**/*.jpg"))) + self.images.extend(list(Path(image_dir).glob("**/*.png"))) self.scale_range = scale_range self.nb = nb @@ -69,14 +70,15 @@ class RandomPaste(A.DualTransform): return False def get_params_dependent_on_targets(self, params): - # choose a random image inside the image folder - filename = rd.choice(os.listdir(self.path_paste_img_dir)) + # choose a random image and its corresponding mask + img_path = rd.choice(self.images) + mask_path = img_path.parent.joinpath("MASK.PNG") # load the "paste" image paste_img = Image.open( os.path.join( self.path_paste_img_dir, - filename, + img_path, ) ).convert("RGBA") @@ -84,25 +86,23 @@ class RandomPaste(A.DualTransform): paste_mask = Image.open( os.path.join( self.path_paste_mask_dir, - filename, + mask_path, ) ).convert("LA") # load the target image target_img = params["image"] + + # compute shapes, for easier computations target_shape = np.array(target_img.shape[:2], dtype=np.uint) paste_shape = np.array(paste_img.size, dtype=np.uint) - # change paste_img's brightness randomly - filter = ImageEnhance.Brightness(paste_img) - paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - # change paste_img's contrast randomly filter = ImageEnhance.Contrast(paste_img) paste_img = filter.enhance(rd.uniform(0.5, 1.5)) - # change paste_img's sharpness randomly - filter = ImageEnhance.Sharpness(paste_img) + # change paste_img's brightness randomly + filter = ImageEnhance.Brightness(paste_img) paste_img = filter.enhance(rd.uniform(0.5, 1.5)) # compute the minimum scaling to fit inside target image From 92058da1d50a8dfe590715790f1696c3e2ca8736 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 7 Jul 2022 13:40:00 +0200 Subject: [PATCH 11/15] feat: fuck yeah Former-commit-id: d472cf758e1761df9f15d1f5c7448cc4274d089f [formerly 4cd21f3ca4f0e8e22f22f61ac76e5ed4478e6937] Former-commit-id: 9b2b0d5dd11d7b7804b1e64d52fcff2d2ea43a0b --- comp.ipynb.REMOVED.git-id | 2 +- src/train.py | 7 +++---- 
src/unet/model.py | 2 +- src/utils/paste.py | 18 ++---------------- 4 files changed, 7 insertions(+), 22 deletions(-) diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id index 3c6779c..6ca72da 100644 --- a/comp.ipynb.REMOVED.git-id +++ b/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -3c9a34f197340a6051eb34d11695c7d6b72164f0 \ No newline at end of file +5ef2ef54312186cd3e3162869c4f237b69de3b1e \ No newline at end of file diff --git a/src/train.py b/src/train.py index eb1b115..1f72d51 100644 --- a/src/train.py +++ b/src/train.py @@ -13,13 +13,12 @@ CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", - "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", - "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", + "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/", # "FEATURES": [1, 2, 4, 8], # "FEATURES": [4, 8, 16, 32], - # "FEATURES": [8, 16, 32, 64], + "FEATURES": [8, 16, 32, 64], # "FEATURES": [4, 8, 16, 32, 64], - "FEATURES": [8, 16, 32, 64, 128], + # "FEATURES": [8, 16, 32, 64, 128], # "FEATURES": [16, 32, 64, 128], # "FEATURES": [64, 128, 256, 512], "N_CHANNELS": 3, diff --git a/src/unet/model.py b/src/unet/model.py index 4508b27..8d5288f 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -73,7 +73,7 @@ class UNet(pl.LightningModule): A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), A.Flip(), A.ColorJitter(), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), A.GaussianBlur(), A.ISONoise(), A.ToFloat(max_value=255), diff --git a/src/utils/paste.py b/src/utils/paste.py index a25289b..0c4db00 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -1,4 +1,3 @@ -import os import random as rd from pathlib import Path @@ -75,20 +74,10 @@ class RandomPaste(A.DualTransform): mask_path = img_path.parent.joinpath("MASK.PNG") # load the "paste" image - paste_img = Image.open( - os.path.join( - self.path_paste_img_dir, - img_path, - ) - ).convert("RGBA") + paste_img = Image.open(img_path).convert("RGBA") # load its respective mask - paste_mask = Image.open( - os.path.join( - self.path_paste_mask_dir, - mask_path, - ) - ).convert("LA") + paste_mask = Image.open(mask_path).convert("LA") # load the target image target_img = params["image"] @@ -151,6 +140,3 @@ class RandomPaste(A.DualTransform): ) return params - - def get_transform_init_args_names(self): - return "scale_range", "path_paste_img_dir", "path_paste_mask_dir" From 8611d8cd7a372a40df31a940e3567c50013902a1 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Thu, 7 Jul 2022 16:31:53 +0200 Subject: [PATCH 12/15] feat: better paste augmentation Former-commit-id: 2adef7920e5f317ac3fbe0205862e29d49c2af8f [formerly 41cb0c231b00a1e992847723eb754af1a9e28eee] Former-commit-id: f826c62f4aa3b0c9d2ea7b49f49b5839072ff259 --- comp.ipynb.REMOVED.git-id | 2 +- src/train.py | 6 -- src/utils/paste.py | 157 +++++++++++++++++++++++--------------- 3 files changed, 97 insertions(+), 68 deletions(-) diff --git a/comp.ipynb.REMOVED.git-id b/comp.ipynb.REMOVED.git-id index 6ca72da..21ecac1 100644 --- a/comp.ipynb.REMOVED.git-id +++ b/comp.ipynb.REMOVED.git-id @@ -1 +1 @@ -5ef2ef54312186cd3e3162869c4f237b69de3b1e \ No newline at end of file +0f3136c724eea42fdf1ee15e721ef33604e9a46d \ No newline at end of file diff --git a/src/train.py b/src/train.py index 
1f72d51..bb0c2e4 100644 --- a/src/train.py +++ b/src/train.py @@ -14,13 +14,7 @@ CONFIG = { "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/", - # "FEATURES": [1, 2, 4, 8], - # "FEATURES": [4, 8, 16, 32], "FEATURES": [8, 16, 32, 64], - # "FEATURES": [4, 8, 16, 32, 64], - # "FEATURES": [8, 16, 32, 64, 128], - # "FEATURES": [16, 32, 64, 128], - # "FEATURES": [64, 128, 256, 512], "N_CHANNELS": 3, "N_CLASSES": 1, "AMP": True, diff --git a/src/utils/paste.py b/src/utils/paste.py index 0c4db00..bf662eb 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -3,7 +3,8 @@ from pathlib import Path import albumentations as A import numpy as np -from PIL import Image, ImageEnhance +import torchvision.transforms as T +from PIL import Image class RandomPaste(A.DualTransform): @@ -38,105 +39,139 @@ class RandomPaste(A.DualTransform): def targets_as_params(self): return ["image"] - def apply(self, img, positions, paste_img, paste_mask, **params): + def apply(self, img, augmentations, paste_img, paste_mask, **params): # convert img to Image, needed for `paste` function img = Image.fromarray(img) + # copy paste_img and paste_mask + paste_mask = paste_mask.copy() + paste_img = paste_img.copy() + # paste spheres - for pos in positions: - img.paste(paste_img, pos, paste_mask) + for (x, y, shearx, sheary, shape, angle, brightness, contrast) in augmentations: + paste_img = T.functional.adjust_contrast( + paste_img, + contrast_factor=contrast, + ) + paste_img = T.functional.adjust_brightness( + paste_img, + brightness_factor=brightness, + ) + paste_img = T.functional.affine( + paste_img, + scale=0.95, + angle=angle, + translate=(0, 0), + shear=(shearx, sheary), + interpolation=T.InterpolationMode.BICUBIC, + ) + paste_img = T.functional.resize( + paste_img, + size=shape, + interpolation=T.InterpolationMode.BICUBIC, + ) + + paste_mask = T.functional.affine( + paste_mask, + scale=0.95, + angle=angle, + translate=(0, 0), + shear=(shearx, sheary), + interpolation=T.InterpolationMode.BICUBIC, + ) + paste_mask = T.functional.resize( + paste_mask, + size=shape, + interpolation=T.InterpolationMode.BICUBIC, + ) + + img.paste(paste_img, (x, y), paste_mask) return np.asarray(img.convert("RGB")) - def apply_to_mask(self, mask, positions, paste_mask, **params): + def apply_to_mask(self, mask, augmentations, paste_mask, **params): # convert mask to Image, needed for `paste` function mask = Image.fromarray(mask) - # binarize the mask -> {0, 1} - paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0) + # copy paste_img and paste_mask + paste_mask = paste_mask.copy() - # paste spheres - for pos in positions: - mask.paste(paste_mask, pos, paste_mask_bin) + for (x, y, shearx, sheary, shape, angle, _, _) in augmentations: + paste_mask = T.functional.affine( + paste_mask, + scale=0.95, + angle=angle, + translate=(0, 0), + shear=(shearx, sheary), + interpolation=T.InterpolationMode.BICUBIC, + ) + paste_mask = T.functional.resize( + paste_mask, + size=shape, + interpolation=T.InterpolationMode.BICUBIC, + ) + + # binarize the mask -> {0, 1} + paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0) + + mask.paste(paste_mask, (x, y), paste_mask_bin) return np.asarray(mask.convert("L")) - @staticmethod - def overlap(positions, x1, y1, w, h): - for x2, y2 in positions: - if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h: - return True - return False - def 
get_params_dependent_on_targets(self, params):
         # choose a random image and its corresponding mask
         img_path = rd.choice(self.images)
         mask_path = img_path.parent.joinpath("MASK.PNG")
 
-        # load the "paste" image
+        # load images (w/ transparency)
         paste_img = Image.open(img_path).convert("RGBA")
-
-        # load its respective mask
         paste_mask = Image.open(mask_path).convert("LA")
-
-        # load the target image
         target_img = params["image"]
 
-        # compute shapes, for easier computations
+        # compute shapes
         target_shape = np.array(target_img.shape[:2], dtype=np.uint)
         paste_shape = np.array(paste_img.size, dtype=np.uint)
 
-        # change paste_img's contrast randomly
-        filter = ImageEnhance.Contrast(paste_img)
-        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
-
-        # change paste_img's brightness randomly
-        filter = ImageEnhance.Brightness(paste_img)
-        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
-
-        # compute the minimum scaling to fit inside target image
+        # compute minimum scaling to fit inside target
         min_scale = np.min(target_shape / paste_shape)
 
-        # randomize the relative scaling
-        scale = rd.uniform(*self.scale_range)
-
-        # rotate the image and its mask
-        angle = rd.uniform(0, 360)
-        paste_img = paste_img.rotate(angle, expand=True)
-        paste_mask = paste_mask.rotate(angle, expand=True)
-
-        # scale the "paste" image and its mask
-        paste_img = paste_img.resize(
-            tuple((paste_shape * min_scale * scale).astype(np.uint)),
-            resample=Image.Resampling.LANCZOS,
-        )
-        paste_mask = paste_mask.resize(
-            tuple((paste_shape * min_scale * scale).astype(np.uint)),
-            resample=Image.Resampling.LANCZOS,
-        )
-
-        # update paste_shape after scaling
-        paste_shape = np.array(paste_img.size, dtype=np.uint)
-
-        # generate some positions
-        positions = []
+        # generate augmentations
+        augmentations = []
         NB = rd.randint(1, self.nb)
 
-        while len(positions) < NB:
-            x = rd.randint(0, target_shape[0] - paste_shape[0])
-            y = rd.randint(0, target_shape[1] - paste_shape[1])
+        while len(augmentations) < NB:  # TODO: add a stopping condition (max iterations)
+            scale = rd.uniform(*self.scale_range) * min_scale
+            shape = np.array(paste_shape * scale, dtype=np.uint)
+
+            x = rd.randint(0, target_shape[0] - shape[0])
+            y = rd.randint(0, target_shape[1] - shape[1])
 
             # check for overlapping
-            if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
+            if RandomPaste.overlap(augmentations, x, y, shape[0], shape[1]):
                 continue
 
-            positions.append((x, y))
+            shearx = rd.uniform(-2, 2)
+            sheary = rd.uniform(-2, 2)
+
+            angle = rd.uniform(0, 360)
+
+            brightness = rd.uniform(0.8, 1.2)
+            contrast = rd.uniform(0.8, 1.2)
+
+            augmentations.append((x, y, shearx, sheary, tuple(shape), angle, brightness, contrast))
 
         params.update(
             {
-                "positions": positions,
+                "augmentations": augmentations,
                 "paste_img": paste_img,
                 "paste_mask": paste_mask,
             }
         )
 
         return params
+
+    @staticmethod
+    def overlap(positions, x1, y1, w, h):
+        for x2, y2, _, _, _, _, _, _ in positions:
+            if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
+                return True
+        return False
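A side note on the placement logic above: RandomPaste.overlap is a plain axis-aligned bounding-box intersection test, and it reuses the candidate's w and h for the boxes already placed, so it implicitly assumes all pasted spheres have comparable sizes. A tiny standalone illustration with hypothetical coordinates, padded to the 8-field tuple layout used by the augmentations list above:

# one sphere already placed at (0, 0); overlap() only reads x and y from placed entries
placed = [(0, 0, 0.0, 0.0, (10, 10), 0.0, 1.0, 1.0)]

print(RandomPaste.overlap(placed, 5, 5, 10, 10))  # True: two 10x10 boxes 5 px apart intersect
print(RandomPaste.overlap(placed, 20, 20, 10, 10))  # False: disjoint on both axes

From 90978bfdc3221d7b6796f00f3537b6b6bb26da8d Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Fri, 8 Jul 2022 09:54:45 +0200
Subject: [PATCH 13/15] feat: ugly training image logging

Former-commit-id: 16a25008320f436069cff9f44bf013c1c2d0f890 [formerly 683afc2cb6322ce3f1d98797b947cca8c6af09a4]
Former-commit-id: a5dae735e10107b514f028e84084ce7a303216ef
---
 src/train.py | 3 +--
 src/unet/blocks.py | 2 +-
 src/unet/model.py | 42 +++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 43 insertions(+), 4 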
deletions(-) diff --git a/src/train.py b/src/train.py index bb0c2e4..d111e38 100644 --- a/src/train.py +++ b/src/train.py @@ -4,7 +4,6 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger -from torch.utils.data import DataLoader import wandb from unet import UNet @@ -13,7 +12,7 @@ CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", - "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/", + "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres/", "FEATURES": [8, 16, 32, 64], "N_CHANNELS": 3, "N_CLASSES": 1, diff --git a/src/unet/blocks.py b/src/unet/blocks.py index 1f4a854..d125002 100644 --- a/src/unet/blocks.py +++ b/src/unet/blocks.py @@ -1,4 +1,4 @@ -""" Parts of the U-Net model """ +"""Parts of the U-Net model.""" import torch import torch.nn as nn diff --git a/src/unet/model.py b/src/unet/model.py index 8d5288f..11ddc65 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -130,6 +130,46 @@ class UNet(pl.LightningModule): }, ) + if batch_idx == 22000: + rows = [] + columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] + for i, (img, mask, pred, pred_bin) in enumerate( + zip( + images.cpu(), + masks_true.cpu(), + masks_pred.cpu(), + masks_pred_bin.cpu().squeeze(1).int().numpy(), + ) + ): + rows.append( + [ + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + dice, + dice_bin, + ] + ) + + # logging + try: # required by autofinding, logger replaced by dummy + self.logger.log_table( + key="train/predictions", + columns=columns, + data=rows, + ) + except: + pass + return dict( accuracy=accuracy, loss=dice, @@ -155,7 +195,7 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() rows = [] - if batch_idx % 50 == 0: + if batch_idx % 50 == 0 or dice < 0.1: for i, (img, mask, pred, pred_bin) in enumerate( zip( images.cpu(), From 7ca448803b9487a165129228b33124e5a765373c Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Fri, 8 Jul 2022 11:05:44 +0200 Subject: [PATCH 14/15] feat: dynamic onnx ? 
Former-commit-id: 619c74a13d0674fc77bd5c1bf711013c1b3d4626 [formerly 762126125c2f108855a0837f3688f28e1002dcf7] Former-commit-id: 7aa443fd8b68603171d2bbfa87bb9eddbe6dc066 --- .../comp.ipynb.REMOVED.git-id | 0 src/dynamic.ipynb | 105 ++++++++++++++++++ extract.ipynb => src/extract.ipynb | 0 src/unet/blocks.py | 8 +- src/unet/model.py | 4 +- 5 files changed, 114 insertions(+), 3 deletions(-) rename comp.ipynb.REMOVED.git-id => src/comp.ipynb.REMOVED.git-id (100%) create mode 100644 src/dynamic.ipynb rename extract.ipynb => src/extract.ipynb (100%) diff --git a/comp.ipynb.REMOVED.git-id b/src/comp.ipynb.REMOVED.git-id similarity index 100% rename from comp.ipynb.REMOVED.git-id rename to src/comp.ipynb.REMOVED.git-id diff --git a/src/dynamic.ipynb b/src/dynamic.ipynb new file mode 100644 index 0000000..371cb92 --- /dev/null +++ b/src/dynamic.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from unet import UNet\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "net = UNet(\n", + " n_channels=3,\n", + " n_classes=1,\n", + " batch_size=1,\n", + " learning_rate=1e-4,\n", + " features=[8, 16, 32, 64],\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "net.load_state_dict(\n", + " torch.load(\"../best.pth\")\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n", + "torch.onnx.export(\n", + " net,\n", + " dummy_input,\n", + " \"model-test.onnx\",\n", + " opset_version=14,\n", + " input_names=[\"input\"],\n", + " output_names=[\"output\"],\n", + " dynamic_axes={\n", + " \"input\": {\n", + " 2: \"height\",\n", + " 3: \"width\",\n", + " },\n", + " \"output\": {\n", + " 2: \"height\",\n", + " 3: \"width\",\n", + " },\n", + " },\n", + ")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.0 ('.venv': poetry)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/extract.ipynb b/src/extract.ipynb similarity index 100% rename from extract.ipynb rename to src/extract.ipynb diff --git a/src/unet/blocks.py b/src/unet/blocks.py index d125002..0df7f5f 100644 --- a/src/unet/blocks.py +++ b/src/unet/blocks.py @@ -59,8 +59,14 @@ class Up(nn.Module): # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] + diffY2 = torch.div(diffY, 2, rounding_mode="trunc") + diffX2 = torch.div(diffX, 2, rounding_mode="trunc") + + x1 = F.pad( + input=x1, + pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2], + ) - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) x = torch.cat([x2, x1], dim=1) return self.conv(x) diff --git a/src/unet/model.py b/src/unet/model.py index 11ddc65..c244735 
100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -268,14 +268,14 @@ class UNet(pl.LightningModule):
 
         # export model to pth
         torch.save(self.state_dict(), f"checkpoints/model.pth")
         artifact = wandb.Artifact("pth", type="model")
-        artifact.add_file(f"checkpoints/model.pth")
+        artifact.add_file("checkpoints/model.pth")
         wandb.run.log_artifact(artifact)
 
         # export model to onnx
         dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
         torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
         artifact = wandb.Artifact("onnx", type="model")
-        artifact.add_file(f"checkpoints/model.onnx")
+        artifact.add_file("checkpoints/model.onnx")
         wandb.run.log_artifact(artifact)
 
     def configure_optimizers(self):

From a53fa4ac92b90dc1233ed5e27f6f8bf84c6e355e Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Fri, 8 Jul 2022 11:21:39 +0200
Subject: [PATCH 15/15] chore: testing with only real pasted spheres

Former-commit-id: 6805c289a6bb55d782fedef5233d90f985f19da3 [formerly 2c25bcbd9238002466a341bc131ff636d8d689aa]
Former-commit-id: e9de7f78afeb6c3e38b6e7cb392ac3beb25ffb0b
---
 src/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/train.py b/src/train.py
index d111e38..a1e54c8 100644
--- a/src/train.py
+++ b/src/train.py
@@ -12,7 +12,7 @@ CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
     "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
     "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
-    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres/",
+    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/realspheres/",
     "FEATURES": [8, 16, 32, 64],
     "N_CHANNELS": 3,
     "N_CLASSES": 1,
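A closing note on the dynamic-axes export introduced in PATCH 14: declaring axes 2 and 3 of "input" and "output" as dynamic lets the single exported file run at resolutions other than the fixed 512x512 of the earlier exports. A minimal inference sketch, assuming onnxruntime is installed; "model-test.onnx" is the file written by src/dynamic.ipynb, and the resolutions below are illustrative:

import numpy as np
import onnxruntime as ort

# one session serves several input resolutions thanks to the dynamic height/width axes
session = ort.InferenceSession("model-test.onnx")

for h, w in [(512, 512), (1024, 768)]:
    x = np.random.rand(1, 3, h, w).astype(np.float32)
    (y,) = session.run(["output"], {"input": x})
    print(x.shape, "->", y.shape)  # the logit mask keeps the input's spatial size

Keeping h and w divisible by the encoder's total downsampling factor avoids leaning on the padding that the reworked Up blocks apply to odd-sized feature maps.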