Merge branch 'pl'

Former-commit-id: 2cff12df68a22afe6b9e1133702203c4c6d86d24 [formerly 4f606e6e20e260cf94c1384ceb65c29c9ac12164]
Former-commit-id: ee4167081531b498b295c0402260da58cc4f7374
Laurent Fainsin 2022-07-08 11:23:02 +02:00
commit 164417b13f
14 changed files with 2683 additions and 1949 deletions

.gitignore vendored (2 changes)

@@ -4,10 +4,12 @@ __pycache__/
 wandb/
 images/
+lightning_logs/
 checkpoints/
 *.pth
 *.onnx
+*.ckpt
 *.png
 *.jpg

.vscode/launch.json vendored (6 changes)

@@ -12,9 +12,11 @@
             "console": "integratedTerminal",
             "args": [
                 "--input",
-                "images/SM.png",
+                "images/test.png",
                 "--output",
-                "output.png",
+                "output_onnx.png",
+                "--model",
+                "good.onnx",
             ],
             "justMyCode": true
         }

(deleted file)

@@ -1 +0,0 @@
-9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d

poetry.lock generated (3454 changes)

File diff suppressed because it is too large

pyproject.toml

@@ -8,8 +8,11 @@ version = "0.1.0"
 albumentations = "^1.2.0"
 matplotlib = "^3.5.2"
 numpy = "^1.23.0"
+onnx = "^1.12.0"
+onnxruntime = "^1.11.1"
 python = ">=3.8,<3.11"
 pytorch-lightning = "^1.6.4"
+rich = "^12.4.4"
 scipy = "^1.8.1"
 torch = "^1.11.0"
 torchvision = "^0.12.0"

(new file)

@@ -0,0 +1 @@
+0f3136c724eea42fdf1ee15e721ef33604e9a46d

src/dynamic.ipynb (new file, 105 lines)

@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from unet import UNet\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"net = UNet(\n",
" n_channels=3,\n",
" n_classes=1,\n",
" batch_size=1,\n",
" learning_rate=1e-4,\n",
" features=[8, 16, 32, 64],\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.load_state_dict(\n",
" torch.load(\"../best.pth\")\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
"torch.onnx.export(\n",
" net,\n",
" dummy_input,\n",
" \"model-test.onnx\",\n",
" opset_version=14,\n",
" input_names=[\"input\"],\n",
" output_names=[\"output\"],\n",
" dynamic_axes={\n",
" \"input\": {\n",
" 2: \"height\",\n",
" 3: \"width\",\n",
" },\n",
" \"output\": {\n",
" 2: \"height\",\n",
" 3: \"width\",\n",
" },\n",
" },\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.0 ('.venv': poetry)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
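
Since the whole point of this notebook is the dynamic_axes argument to torch.onnx.export, a quick sanity check is to run the exported network at several resolutions through ONNX Runtime. A minimal sketch (not part of the commit), assuming the model-test.onnx file written by the cell above:

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession("model-test.onnx")
    input_name = session.get_inputs()[0].name

    for size in (256, 512, 1024):
        x = np.random.rand(1, 3, size, size).astype(np.float32)
        (y,) = session.run(None, {input_name: x})
        # height and width were exported as dynamic, so the output follows the input size
        print(size, y.shape)

Powers of two are the safest sizes here, since each pooling stage of the encoder halves the spatial dimensions.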

src/extract.ipynb (new file, 177 lines)

File diff suppressed because one or more lines are too long

(modified file)

@@ -3,12 +3,12 @@ import logging

 import albumentations as A
 import numpy as np
+import onnx
+import onnxruntime
 import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
-from unet import UNet


 def get_args():
     parser = argparse.ArgumentParser(
@@ -38,23 +38,23 @@ def get_args():
     return parser.parse_args()


+def sigmoid(x):
+    return 1 / (1 + np.exp(-x))


 if __name__ == "__main__":
     args = get_args()
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

-    net = UNet(n_channels=3, n_classes=1)
+    onnx_model = onnx.load(args.model)
+    onnx.checker.check_model(onnx_model)

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logging.info(f"Using device {device}")
+    ort_session = onnxruntime.InferenceSession(args.model)

-    logging.info("Transfering model to device")
-    net.to(device=device)
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

-    logging.info(f"Loading model {args.model}")
-    net.load_state_dict(torch.load(args.model, map_location=device))

-    logging.info(f"Loading image {args.input}")
     img = Image.open(args.input).convert("RGB")

     logging.info(f"Preprocessing image {args.input}")
@@ -68,17 +68,14 @@ if __name__ == "__main__":
     img = aug["image"]

     logging.info(f"Predicting image {args.input}")
-    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
+    img = img.unsqueeze(0)

-    net.eval()
-    with torch.inference_mode():
-        mask = net(img)
-        mask = torch.sigmoid(mask)[0]
-        mask = mask.cpu()
-        mask = mask.squeeze()
-        mask = mask > 0.5
-        mask = np.asarray(mask)
+    # compute ONNX Runtime output prediction
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
+    ort_outs = ort_session.run(None, ort_inputs)

-    logging.info(f"Saving prediction to {args.output}")
-    mask = Image.fromarray(mask)
-    mask.save(args.output)
+    img_out_y = ort_outs[0]
+
+    img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
+    img_out_y.save(args.output)
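
The rewritten script defines a sigmoid helper but saves the raw network output rescaled to 0-255. A hypothetical post-processing sketch (not in the commit) of how that helper could turn the ONNX Runtime logits into a binary mask instead:

    import numpy as np

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    logits = np.random.randn(1, 1, 512, 512).astype(np.float32)  # stand-in for ort_outs[0]
    probs = sigmoid(logits)
    mask = (probs[0, 0] > 0.5).astype(np.uint8) * 255  # suitable for Image.fromarray(..., mode="L")
    print(mask.shape, mask.dtype, mask.min(), mask.max())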

(modified file)

@@ -1,20 +1,33 @@
 import logging

-import albumentations as A
+import pytorch_lightning as pl
 import torch
-import yaml
-from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader
-from tqdm import tqdm
+from pytorch_lightning.callbacks import RichProgressBar
+from pytorch_lightning.loggers import WandbLogger

 import wandb
-from src.utils.dataset import SphereDataset
 from unet import UNet
-from utils.dice import dice_coeff
-from utils.paste import RandomPaste

-class_labels = {
-    1: "sphere",
-}
+CONFIG = {
+    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
+    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
+    "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
+    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/realspheres/",
+    "FEATURES": [8, 16, 32, 64],
+    "N_CHANNELS": 3,
+    "N_CLASSES": 1,
+    "AMP": True,
+    "PIN_MEMORY": True,
+    "BENCHMARK": True,
+    "DEVICE": "gpu",
+    "WORKERS": 10,
+    "EPOCHS": 1,
+    "BATCH_SIZE": 32,
+    "LEARNING_RATE": 1e-4,
+    "WEIGHT_DECAY": 1e-8,
+    "MOMENTUM": 0.9,
+    "IMG_SIZE": 512,
+    "SPHERES": 5,
+}

 if __name__ == "__main__":
@@ -22,266 +35,49 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

     # setup wandb
-    wandb.init(
+    logger = WandbLogger(
         project="U-Net",
-        config=dict(
-            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
-            DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
-            DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
-            DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
-            DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-            FEATURES=[16, 32, 64, 128],
-            N_CHANNELS=3,
-            N_CLASSES=1,
-            AMP=True,
-            PIN_MEMORY=True,
-            BENCHMARK=True,
-            DEVICE="cuda",
-            WORKERS=7,
-            EPOCHS=5,
-            BATCH_SIZE=16,
-            LEARNING_RATE=1e-4,
-            WEIGHT_DECAY=1e-8,
-            MOMENTUM=0.9,
-            IMG_SIZE=512,
-            SPHERES=5,
-        ),
+        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),
     )

-    # create device
-    device = torch.device(wandb.config.DEVICE)
+    # seed random generators
+    pl.seed_everything(69420, workers=True)

-    # enable cudnn benchmarking
-    torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
-
-    # 0. Create network
-    net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
-    wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
-
-    # transfer network to device
-    net.to(device=device)
-
-    # 1. Create transforms
-    tf_train = A.Compose(
-        [
-            A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-            A.Flip(),
-            A.ColorJitter(),
-            RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
-            A.GaussianBlur(),
-            A.ISONoise(),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
-    tf_valid = A.Compose(
-        [
-            A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-            RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
-
-    # 2. Create datasets
-    ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
-    ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
-
-    # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
-
-    # 3. Create data loaders
-    train_loader = DataLoader(
-        ds_train,
-        shuffle=True,
-        batch_size=wandb.config.BATCH_SIZE,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
-    )
-    val_loader = DataLoader(
-        ds_valid,
-        shuffle=False,
-        drop_last=True,
-        batch_size=1,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
-    )
-
-    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.RMSprop(
-        net.parameters(),
-        lr=wandb.config.LEARNING_RATE,
-        weight_decay=wandb.config.WEIGHT_DECAY,
-        momentum=wandb.config.MOMENTUM,
-    )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
-    grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
-    criterion = torch.nn.BCEWithLogitsLoss()
-
-    # save model.onxx
-    dummy_input = torch.randn(
-        1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
-    ).to(device)
-    torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
-    artifact = wandb.Artifact("onnx", type="model")
-    artifact.add_file("checkpoints/model.onnx")
-    wandb.run.log_artifact(artifact)
-
-    # log gradients and weights four time per epoch
-    wandb.watch(net, log_freq=100)
-
-    # print the config
-    logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
-
-    # wandb init log
-    wandb.log(
-        {
-            "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-        },
-        commit=False,
-    )
+    # Create network
+    net = UNet(
+        n_channels=CONFIG["N_CHANNELS"],
+        n_classes=CONFIG["N_CLASSES"],
+        batch_size=CONFIG["BATCH_SIZE"],
+        learning_rate=CONFIG["LEARNING_RATE"],
+        features=CONFIG["FEATURES"],
+    )
+
+    # log gradients and weights regularly
+    logger.watch(net, log="all")
+
+    # Create the trainer
+    trainer = pl.Trainer(
+        max_epochs=CONFIG["EPOCHS"],
+        accelerator=CONFIG["DEVICE"],
+        # precision=16,
+        # auto_scale_batch_size="binsearch",
+        # auto_lr_find=True,
+        benchmark=CONFIG["BENCHMARK"],
+        val_check_interval=100,
+        callbacks=RichProgressBar(),
+        logger=logger,
+        log_every_n_steps=1,
+    )

     try:
-        global_step = 0
-        for epoch in range(1, wandb.config.EPOCHS + 1):
-            with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
-
-                # Training round
-                for step, (images, true_masks) in enumerate(train_loader):
-                    assert images.shape[1] == net.n_channels, (
-                        f"Network has been defined with {net.n_channels} input channels, "
-                        f"but loaded images have {images.shape[1]} channels. Please check that "
-                        "the images are loaded correctly."
-                    )
-
-                    # transfer images to device
-                    images = images.to(device=device)
-                    true_masks = true_masks.unsqueeze(1).to(device=device)
-
-                    # forward
-                    with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
-                        pred_masks = net(images)
-                        train_loss = criterion(pred_masks, true_masks)
-
-                    # backward
-                    optimizer.zero_grad(set_to_none=True)
-                    grad_scaler.scale(train_loss).backward()
-                    grad_scaler.step(optimizer)
-                    grad_scaler.update()
-
-                    # compute metrics
-                    pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
-                    accuracy = (true_masks == pred_masks_bin).float().mean()
-                    dice = dice_coeff(pred_masks_bin, true_masks)
-                    mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
-
-                    # update tqdm progress bar
-                    pbar.update(images.shape[0])
-                    pbar.set_postfix(**{"loss": train_loss.item()})
-
-                    # log metrics
-                    wandb.log(
-                        {
-                            "train/accuracy": accuracy,
-                            "train/bce": train_loss,
-                            "train/dice": dice,
-                            "train/mae": mae,
-                        }
-                    )
-
-                    global_step += 1
-                    if global_step % 100 == 0:
-
-                        # Evaluation round
-                        net.eval()
-                        accuracy = 0
-                        val_loss = 0
-                        dice = 0
-                        mae = 0
-
-                        with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
-                            for images, masks_true in val_loader:
-
-                                # transfer images to device
-                                images = images.to(device=device)
-                                masks_true = masks_true.unsqueeze(1).to(device=device)
-
-                                # forward
-                                with torch.inference_mode():
-                                    masks_pred = net(images)
-
-                                # compute metrics
-                                val_loss += criterion(masks_pred, masks_true)
-                                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                                mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-                                accuracy += (masks_true == masks_pred_bin).float().mean()
-                                dice += dice_coeff(masks_pred_bin, masks_true)
-
-                                # update progress bars
-                                pbar2.update(images.shape[0])
-                                pbar.refresh()
-
-                        accuracy /= len(val_loader)
-                        val_loss /= len(val_loader)
-                        dice /= len(val_loader)
-                        mae /= len(val_loader)
-
-                        # save the last validation batch to table
-                        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-                        for i, (img, mask, pred, pred_bin) in enumerate(
-                            zip(
-                                images.cpu(),
-                                masks_true.cpu(),
-                                masks_pred.cpu(),
-                                masks_pred_bin.cpu().squeeze(1).int().numpy(),
-                            )
-                        ):
-                            table.add_data(
-                                i,
-                                wandb.Image(img),
-                                wandb.Image(mask),
-                                wandb.Image(
-                                    pred,
-                                    masks={
-                                        "predictions": {
-                                            "mask_data": pred_bin,
-                                            "class_labels": class_labels,
-                                        },
-                                    },
-                                ),
-                            )
-
-                        # log validation metrics
-                        wandb.log(
-                            {
-                                "val/predictions": table,
-                                "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-                                "val/accuracy": accuracy,
-                                "val/bce": val_loss,
-                                "val/dice": dice,
-                                "val/mae": mae,
-                            },
-                            commit=False,
-                        )
-
-                        # update hyperparameters
-                        net.train()
-                        scheduler.step(dice)
-
-                        # export model to onnx format when validation ends
-                        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
-                        torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
-                        artifact = wandb.Artifact("onnx", type="model")
-                        artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
-                        wandb.run.log_artifact(artifact)
-
-        # stop wandb
-        wandb.run.finish()
+        trainer.tune(net)
+        trainer.fit(model=net)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise

-    # christmas tree
+    # stop wandb
+    wandb.run.finish()
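
Because the module now calls self.save_hyperparameters() (see the model file below), a Lightning checkpoint written during this run can be reloaded without repeating the constructor arguments. A minimal sketch (not part of the commit; the .ckpt path is a placeholder, and note the new *.ckpt entry in .gitignore):

    from unet import UNet

    net = UNet.load_from_checkpoint("path/to/checkpoint.ckpt")
    net.eval()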

(modified file)

@@ -1,4 +1,4 @@
-""" Parts of the U-Net model """
+"""Parts of the U-Net model."""

 import torch
 import torch.nn as nn
@@ -59,8 +59,14 @@ class Up(nn.Module):
         # input is CHW
         diffY = x2.size()[2] - x1.size()[2]
         diffX = x2.size()[3] - x1.size()[3]
+        diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
+        diffX2 = torch.div(diffX, 2, rounding_mode="trunc")

-        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
+        x1 = F.pad(
+            input=x1,
+            pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
+        )

         x = torch.cat([x2, x1], dim=1)
         return self.conv(x)
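
This padding aligns the upsampled decoder map with its encoder skip before concatenation; torch.div with rounding_mode="trunc" presumably replaces the Python // so the computation stays traceable for the ONNX export. A toy eager-mode sketch of the same alignment (shapes invented for illustration):

    import torch
    import torch.nn.functional as F

    x1 = torch.randn(1, 16, 60, 61)  # upsampled decoder features
    x2 = torch.randn(1, 16, 64, 64)  # encoder skip connection

    diffY = x2.size(2) - x1.size(2)  # 4
    diffX = x2.size(3) - x1.size(3)  # 3

    # pad order is [left, right, top, bottom]
    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

    x = torch.cat([x2, x1], dim=1)
    print(x.shape)  # torch.Size([1, 32, 64, 64])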

(modified file)

@@ -1,14 +1,38 @@
 """ Full assembly of the parts to form the complete network """

+import itertools
+
+import albumentations as A
+import pytorch_lightning as pl
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import DataLoader
+
+import wandb
+from src.utils.dataset import SphereDataset
+from utils.dice import dice_loss
+from utils.paste import RandomPaste
+
 from .blocks import *

+class_labels = {
+    1: "sphere",
+}
+

-class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
+class UNet(pl.LightningModule):
+    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
         super(UNet, self).__init__()

+        # Hyperparameters
         self.n_channels = n_channels
         self.n_classes = n_classes
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+
+        # log hyperparameters
+        self.save_hyperparameters()

+        # Network
         self.inc = DoubleConv(n_channels, features[0])

         self.downs = nn.ModuleList()
@@ -28,6 +52,8 @@ class UNet(nn.Module):
     def forward(self, x):
         skips = []

+        x = x.to(self.device)
+
         x = self.inc(x)

         for down in self.downs:
@@ -40,3 +66,224 @@ class UNet(nn.Module):
         x = self.outc(x)

         return x
+
+    def train_dataloader(self):
+        tf_train = A.Compose(
+            [
+                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+                A.Flip(),
+                A.ColorJitter(),
+                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
+                A.GaussianBlur(),
+                A.ISONoise(),
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+
+        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
+        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+
+        return DataLoader(
+            ds_train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def val_dataloader(self):
+        ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+
+        return DataLoader(
+            ds_valid,
+            shuffle=False,
+            batch_size=1,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def training_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+
+        # forward pass
+        masks_pred = self(images)
+
+        # compute metrics
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
+        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+
+        self.log_dict(
+            {
+                "train/accuracy": accuracy,
+                "train/dice": dice,
+                "train/dice_bin": dice_bin,
+                "train/bce": bce,
+                "train/mae": mae,
+            },
+        )
+
+        if batch_idx == 22000:
+            rows = []
+            columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    masks_true.cpu(),
+                    masks_pred.cpu(),
+                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        dice,
+                        dice_bin,
+                    ]
+                )
+
+            # logging
+            try:  # required by autofinding, logger replaced by dummy
+                self.logger.log_table(
+                    key="train/predictions",
+                    columns=columns,
+                    data=rows,
+                )
+            except:
+                pass
+
+        return dict(
+            accuracy=accuracy,
+            loss=dice,
+            bce=bce,
+            mae=mae,
+        )
+
+    def validation_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+
+        # forward pass
+        masks_pred = self(images)
+
+        # compute metrics
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
+        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+
+        rows = []
+        if batch_idx % 50 == 0 or dice < 0.1:
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    masks_true.cpu(),
+                    masks_pred.cpu(),
+                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        dice,
+                        dice_bin,
+                    ]
+                )
+
+        return dict(
+            accuracy=accuracy,
+            loss=dice,
+            dice_bin=dice_bin,
+            bce=bce,
+            mae=mae,
+            table_rows=rows,
+        )
+
+    def validation_epoch_end(self, validation_outputs):
+        # metrics unpacking
+        accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
+        dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
+        loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
+        bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
+        mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
+
+        # table unpacking
+        columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
+        rowss = [d["table_rows"] for d in validation_outputs]
+        rows = list(itertools.chain.from_iterable(rowss))
+
+        # logging
+        try:  # required by autofinding, logger replaced by dummy
+            self.logger.log_table(
+                key="val/predictions",
+                columns=columns,
+                data=rows,
+            )
+        except:
+            pass
+
+        self.log_dict(
+            {
+                "val/accuracy": accuracy,
+                "val/dice": loss,
+                "val/dice_bin": dice_bin,
+                "val/bce": bce,
+                "val/mae": mae,
+            }
+        )
+
+        # export model to pth
+        torch.save(self.state_dict(), f"checkpoints/model.pth")
+        artifact = wandb.Artifact("pth", type="model")
+        artifact.add_file("checkpoints/model.pth")
+        wandb.run.log_artifact(artifact)
+
+        # export model to onnx
+        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
+        torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
+        artifact = wandb.Artifact("onnx", type="model")
+        artifact.add_file("checkpoints/model.onnx")
+        wandb.run.log_artifact(artifact)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.RMSprop(
+            self.parameters(),
+            lr=self.learning_rate,
+            weight_decay=wandb.config.WEIGHT_DECAY,
+            momentum=wandb.config.MOMENTUM,
+        )
+
+        return optimizer
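
For reference, a minimal inference sketch (not in the commit), assuming the hyperparameters from train.py and the checkpoints/model.pth file written in validation_epoch_end above:

    import torch
    from unet import UNet

    net = UNet(n_channels=3, n_classes=1, learning_rate=1e-4, batch_size=32, features=[8, 16, 32, 64])
    net.load_state_dict(torch.load("checkpoints/model.pth", map_location="cpu"))
    net.eval()

    with torch.inference_mode():
        x = torch.rand(1, 3, 512, 512)  # dummy RGB image in [0, 1]
        probs = torch.sigmoid(net(x))   # forward returns logits
        mask = (probs > 0.5).float()
    print(mask.shape)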

(modified file)

@@ -1,80 +1,18 @@
 import torch
-from torch import Tensor


-def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
-    """Average of Dice coefficient for all batches, or for a single mask.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Raises:
-        ValueError: _description_
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    if input.dim() == 2 and reduce_batch_first:
-        raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
-
-    if input.dim() == 2 or reduce_batch_first:
-        inter = torch.dot(input.reshape(-1), target.reshape(-1))
-        sets_sum = torch.sum(input) + torch.sum(target)
-
-        if sets_sum.item() == 0:
-            sets_sum = 2 * inter
-
-        return (2 * inter + epsilon) / (sets_sum + epsilon)
-    else:
-        # compute and average metric for each batch element
-        dice = 0
-        for i in range(input.shape[0]):
-            dice += dice_coeff(input[i, ...], target[i, ...])
-
-        return dice / input.shape[0]
-
-
-def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
-    """Average of Dice coefficient for all classes.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    dice = 0
-    for channel in range(input.shape[1]):
-        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
-
-    return dice / input.shape[1]
-
-
-def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
-    """Dice loss (objective to minimize) between 0 and 1.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        multiclass (bool, optional): _description_. Defaults to False.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    fn = multiclass_dice_coeff if multiclass else dice_coeff
-    return 1 - fn(input, target, reduce_batch_first=True)
+def dice_score(inputs, targets, smooth=1, logits=True):
+    # comment out if your model contains a sigmoid or equivalent activation layer
+    if logits:
+        inputs = torch.sigmoid(inputs)
+
+    # flatten label and prediction tensors
+    inputs = inputs.view(-1)
+    targets = targets.view(-1)
+
+    intersection = (inputs * targets).sum()
+    return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
+
+
+def dice_loss(inputs, targets, smooth=1, logits=True):
+    return 1 - dice_score(inputs, targets, smooth, logits)
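
A toy check (not in the commit) of the new dice_score on a perfect and an empty prediction; logits=False skips the sigmoid, matching how dice_bin is computed in the model file. The function is copied inline so the snippet runs standalone:

    import torch

    def dice_score(inputs, targets, smooth=1, logits=True):
        if logits:
            inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

    target = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
    print(dice_score(target.clone(), target, logits=False))    # 1.0, perfect overlap
    print(dice_score(torch.zeros(2, 2), target, logits=False)) # 1/3, smoothing avoids 0/0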

(modified file)

@@ -1,9 +1,10 @@
-import os
 import random as rd
+from pathlib import Path

 import albumentations as A
 import numpy as np
-from PIL import Image, ImageEnhance
+import torchvision.transforms as T
+from PIL import Image


 class RandomPaste(A.DualTransform):
@@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform):
     def __init__(
         self,
         nb,
-        path_paste_img_dir,
-        path_paste_mask_dir,
-        scale_range=(0.1, 0.2),
+        image_dir,
+        scale_range=(0.05, 0.25),
         always_apply=True,
         p=1.0,
     ):
         super().__init__(always_apply, p)
-        self.path_paste_img_dir = path_paste_img_dir
-        self.path_paste_mask_dir = path_paste_mask_dir
+        self.images = []
+        self.images.extend(list(Path(image_dir).glob("**/*.jpg")))
+        self.images.extend(list(Path(image_dir).glob("**/*.png")))
         self.scale_range = scale_range
         self.nb = nb
@@ -38,105 +39,129 @@ class RandomPaste(A.DualTransform):
     def targets_as_params(self):
         return ["image"]

-    def apply(self, img, positions, paste_img, paste_mask, **params):
+    def apply(self, img, augmentations, paste_img, paste_mask, **params):
         # convert img to Image, needed for `paste` function
         img = Image.fromarray(img)

+        # copy paste_img and paste_mask
+        paste_mask = paste_mask.copy()
+        paste_img = paste_img.copy()
+
         # paste spheres
-        for pos in positions:
-            img.paste(paste_img, pos, paste_mask)
+        for (x, y, shearx, sheary, shape, angle, brightness, contrast) in augmentations:
+            paste_img = T.functional.adjust_contrast(
+                paste_img,
+                contrast_factor=contrast,
+            )
+            paste_img = T.functional.adjust_brightness(
+                paste_img,
+                brightness_factor=brightness,
+            )
+            paste_img = T.functional.affine(
+                paste_img,
+                scale=0.95,
+                angle=angle,
+                translate=(0, 0),
+                shear=(shearx, sheary),
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+            paste_img = T.functional.resize(
+                paste_img,
+                size=shape,
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+
+            paste_mask = T.functional.affine(
+                paste_mask,
+                scale=0.95,
+                angle=angle,
+                translate=(0, 0),
+                shear=(shearx, sheary),
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+            paste_mask = T.functional.resize(
+                paste_mask,
+                size=shape,
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+
+            img.paste(paste_img, (x, y), paste_mask)

         return np.asarray(img.convert("RGB"))

-    def apply_to_mask(self, mask, positions, paste_mask, **params):
+    def apply_to_mask(self, mask, augmentations, paste_mask, **params):
         # convert mask to Image, needed for `paste` function
         mask = Image.fromarray(mask)

-        # binarize the mask -> {0, 1}
-        paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
+        # copy paste_img and paste_mask
+        paste_mask = paste_mask.copy()

-        # paste spheres
-        for pos in positions:
-            mask.paste(paste_mask, pos, paste_mask_bin)
+        for (x, y, shearx, sheary, shape, angle, _, _) in augmentations:
+            paste_mask = T.functional.affine(
+                paste_mask,
+                scale=0.95,
+                angle=angle,
+                translate=(0, 0),
+                shear=(shearx, sheary),
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+            paste_mask = T.functional.resize(
+                paste_mask,
+                size=shape,
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+
+            # binarize the mask -> {0, 1}
+            paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
+
+            mask.paste(paste_mask, (x, y), paste_mask_bin)

         return np.asarray(mask.convert("L"))

-    @staticmethod
-    def overlap(positions, x1, y1, w, h):
-        for x2, y2 in positions:
-            if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
-                return True
-        return False
-
     def get_params_dependent_on_targets(self, params):
-        # choose a random image inside the image folder
-        filename = rd.choice(os.listdir(self.path_paste_img_dir))
+        # choose a random image and its corresponding mask
+        img_path = rd.choice(self.images)
+        mask_path = img_path.parent.joinpath("MASK.PNG")

-        # load the "paste" image
-        paste_img = Image.open(
-            os.path.join(
-                self.path_paste_img_dir,
-                filename,
-            )
-        ).convert("RGBA")
-
-        # load its respective mask
-        paste_mask = Image.open(
-            os.path.join(
-                self.path_paste_mask_dir,
-                filename,
-            )
-        ).convert("LA")
+        # load images (w/ transparency)
+        paste_img = Image.open(img_path).convert("RGBA")
+        paste_mask = Image.open(mask_path).convert("LA")

-        # load the target image
         target_img = params["image"]
+
+        # compute shapes
         target_shape = np.array(target_img.shape[:2], dtype=np.uint)
         paste_shape = np.array(paste_img.size, dtype=np.uint)

-        # compute the minimum scaling to fit inside target image
+        # compute minimum scaling to fit inside target
         min_scale = np.min(target_shape / paste_shape)

-        # randomize the relative scaling
-        scale = rd.uniform(*self.scale_range)
-
-        # rotate the image and its mask
-        angle = rd.uniform(0, 360)
-        paste_img = paste_img.rotate(angle, expand=True)
-        paste_mask = paste_mask.rotate(angle, expand=True)
-
-        # scale the "paste" image and its mask
-        paste_img = paste_img.resize(
-            tuple((paste_shape * min_scale * scale).astype(np.uint)),
-            resample=Image.Resampling.LANCZOS,
-        )
-        paste_mask = paste_mask.resize(
-            tuple((paste_shape * min_scale * scale).astype(np.uint)),
-            resample=Image.Resampling.LANCZOS,
-        )
-
-        # update paste_shape after scaling
-        paste_shape = np.array(paste_img.size, dtype=np.uint)
-
-        # change brightness randomly
-        filter = ImageEnhance.Brightness(paste_img)
-        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
-
-        # generate some positions
-        positions = []
+        # generate augmentations
+        augmentations = []
         NB = rd.randint(1, self.nb)
-        while len(positions) < NB:
-            x = rd.randint(0, target_shape[0] - paste_shape[0])
-            y = rd.randint(0, target_shape[1] - paste_shape[1])
+        while len(augmentations) < NB:  # TODO: add a stopping condition (max iterations)
+            scale = rd.uniform(*self.scale_range) * min_scale
+            shape = np.array(paste_shape * scale, dtype=np.uint)
+
+            x = rd.randint(0, target_shape[0] - shape[0])
+            y = rd.randint(0, target_shape[1] - shape[1])

             # check for overlapping
-            if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
+            if RandomPaste.overlap(augmentations, x, y, shape[0], shape[1]):
                 continue

-            positions.append((x, y))
+            shearx = rd.uniform(-2, 2)
+            sheary = rd.uniform(-2, 2)
+            angle = rd.uniform(0, 360)
+            brightness = rd.uniform(0.8, 1.2)
+            contrast = rd.uniform(0.8, 1.2)
+
+            augmentations.append((x, y, shearx, sheary, tuple(shape), angle, brightness, contrast))

         params.update(
             {
-                "positions": positions,
+                "augmentations": augmentations,
                 "paste_img": paste_img,
                 "paste_mask": paste_mask,
             }
@@ -144,5 +169,9 @@ class RandomPaste(A.DualTransform):

         return params

-    def get_transform_init_args_names(self):
-        return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"
+    @staticmethod
+    def overlap(positions, x1, y1, w, h):
+        for x2, y2, _, _, _, _, _, _ in positions:
+            if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
+                return True
+        return False
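
Finally, a usage sketch (not in the commit) of the reworked RandomPaste inside an albumentations pipeline, mirroring tf_train in the model file. The directory path is a placeholder and must contain sphere crops with a MASK.PNG next to each image:

    import albumentations as A
    import numpy as np
    from albumentations.pytorch import ToTensorV2
    from utils.paste import RandomPaste

    tf = A.Compose(
        [
            A.Resize(512, 512),
            RandomPaste(5, "/path/to/realspheres/"),  # paste 1 to 5 spheres
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ]
    )

    image = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy RGB frame
    out = tf(image=image, mask=np.zeros((720, 1280), dtype=np.uint8))
    print(out["image"].shape, out["mask"].shape)      # tensors after ToTensorV2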