Merge branch 'pl'

Former-commit-id: 2cff12df68a22afe6b9e1133702203c4c6d86d24 [formerly 4f606e6e20e260cf94c1384ceb65c29c9ac12164]
Former-commit-id: ee4167081531b498b295c0402260da58cc4f7374
Laurent Fainsin 2022-07-08 11:23:02 +02:00
commit 164417b13f
14 changed files with 2683 additions and 1949 deletions

2
.gitignore vendored
View file

@@ -4,10 +4,12 @@ __pycache__/
wandb/
images/
lightning_logs/
checkpoints/
*.pth
*.onnx
*.ckpt
*.png
*.jpg

8
.vscode/launch.json vendored
View file

@@ -12,11 +12,13 @@
"console": "integratedTerminal",
"args": [
"--input",
"images/SM.png",
"images/test.png",
"--output",
"output.png",
"output_onnx.png",
"--model",
"good.onnx",
],
"justMyCode": true
}
]
}
}

View file

@@ -1 +0,0 @@
9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d

3454
poetry.lock generated

File diff suppressed because it is too large

View file

@@ -8,8 +8,11 @@ version = "0.1.0"
albumentations = "^1.2.0"
matplotlib = "^3.5.2"
numpy = "^1.23.0"
onnx = "^1.12.0"
onnxruntime = "^1.11.1"
python = ">=3.8,<3.11"
pytorch-lightning = "^1.6.4"
rich = "^12.4.4"
scipy = "^1.8.1"
torch = "^1.11.0"
torchvision = "^0.12.0"

View file

@@ -0,0 +1 @@
0f3136c724eea42fdf1ee15e721ef33604e9a46d

105
src/dynamic.ipynb Normal file
View file

@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from unet import UNet\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"net = UNet(\n",
" n_channels=3,\n",
" n_classes=1,\n",
" batch_size=1,\n",
" learning_rate=1e-4,\n",
" features=[8, 16, 32, 64],\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.load_state_dict(\n",
" torch.load(\"../best.pth\")\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
"torch.onnx.export(\n",
" net,\n",
" dummy_input,\n",
" \"model-test.onnx\",\n",
" opset_version=14,\n",
" input_names=[\"input\"],\n",
" output_names=[\"output\"],\n",
" dynamic_axes={\n",
" \"input\": {\n",
" 2: \"height\",\n",
" 3: \"width\",\n",
" },\n",
" \"output\": {\n",
" 2: \"height\",\n",
" 3: \"width\",\n",
" },\n",
" },\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.0 ('.venv': poetry)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
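The dynamic height/width axes declared in the export above can be verified with ONNX Runtime. A minimal sketch (assuming onnxruntime is installed and that "model-test.onnx" was produced by the cell above): the same session accepts any spatial size, as long as height and width stay divisible by the network's total downsampling factor.

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model-test.onnx")

# 512x768 instead of the 1024x1024 dummy used at export time
x = np.random.randn(1, 3, 512, 768).astype(np.float32)
(mask,) = session.run(None, {"input": x})
print(mask.shape)  # (1, 1, 512, 768)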

177
src/extract.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@@ -3,12 +3,12 @@ import logging
import albumentations as A
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from unet import UNet
def get_args():
parser = argparse.ArgumentParser(
@@ -38,23 +38,23 @@ def get_args():
return parser.parse_args()
def sigmoid(x):
return 1 / (1 + np.exp(-x))
if __name__ == "__main__":
args = get_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = UNet(n_channels=3, n_classes=1)
onnx_model = onnx.load(args.model)
onnx.checker.check_model(onnx_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
ort_session = onnxruntime.InferenceSession(args.model)
logging.info("Transfering model to device")
net.to(device=device)
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
logging.info(f"Loading model {args.model}")
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info(f"Loading image {args.input}")
img = Image.open(args.input).convert("RGB")
logging.info(f"Preprocessing image {args.input}")
@@ -68,17 +68,14 @@ if __name__ == "__main__":
img = aug["image"]
logging.info(f"Predicting image {args.input}")
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
img = img.unsqueeze(0)
net.eval()
with torch.inference_mode():
mask = net(img)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
mask = mask > 0.5
mask = np.asarray(mask)
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
ort_outs = ort_session.run(None, ort_inputs)
logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask)
mask.save(args.output)
img_out_y = ort_outs[0]
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
img_out_y.save(args.output)
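Note the asymmetry with the removed torch path, which applied a sigmoid and a 0.5 threshold before saving; the new ONNX path above rescales the raw output to 0-255 instead. A minimal sketch (assuming the model's "output" is raw logits, as in the removed path) of recovering a binary mask with the sigmoid helper defined earlier in the script:

mask = sigmoid(ort_outs[0])[0][0] > 0.5  # (H, W) boolean mask
Image.fromarray(np.uint8(mask * 255), mode="L").save(args.output)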

View file

@@ -1,20 +1,33 @@
import logging
import albumentations as A
import pytorch_lightning as pl
import torch
import yaml
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from tqdm import tqdm
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import wandb
from src.utils.dataset import SphereDataset
from unet import UNet
from utils.dice import dice_coeff
from utils.paste import RandomPaste
class_labels = {
1: "sphere",
CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/realspheres/",
"FEATURES": [8, 16, 32, 64],
"N_CHANNELS": 3,
"N_CLASSES": 1,
"AMP": True,
"PIN_MEMORY": True,
"BENCHMARK": True,
"DEVICE": "gpu",
"WORKERS": 10,
"EPOCHS": 1,
"BATCH_SIZE": 32,
"LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 1e-8,
"MOMENTUM": 0.9,
"IMG_SIZE": 512,
"SPHERES": 5,
}
if __name__ == "__main__":
@@ -22,266 +35,49 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# setup wandb
wandb.init(
logger = WandbLogger(
project="U-Net",
config=dict(
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[16, 32, 64, 128],
N_CHANNELS=3,
N_CLASSES=1,
AMP=True,
PIN_MEMORY=True,
BENCHMARK=True,
DEVICE="cuda",
WORKERS=7,
EPOCHS=5,
BATCH_SIZE=16,
LEARNING_RATE=1e-4,
WEIGHT_DECAY=1e-8,
MOMENTUM=0.9,
IMG_SIZE=512,
SPHERES=5,
),
config=CONFIG,
settings=wandb.Settings(
code_dir="./src/",
),
)
# create device
device = torch.device(wandb.config.DEVICE)
# seed random generators
pl.seed_everything(69420, workers=True)
# enable cudnn benchmarking
torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
# 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
# transfer network to device
net.to(device=device)
# 1. Create transforms
tf_train = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(),
A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
tf_valid = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.ToFloat(max_value=255),
ToTensorV2(),
],
# Create network
net = UNet(
n_channels=CONFIG["N_CHANNELS"],
n_classes=CONFIG["N_CLASSES"],
batch_size=CONFIG["BATCH_SIZE"],
learning_rate=CONFIG["LEARNING_RATE"],
features=CONFIG["FEATURES"],
)
# 2. Create datasets
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
# log gradients and weights regularly
logger.watch(net, log="all")
# 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
# 3. Create data loaders
train_loader = DataLoader(
ds_train,
shuffle=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
val_loader = DataLoader(
ds_valid,
shuffle=False,
drop_last=True,
batch_size=1,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
optimizer = torch.optim.RMSprop(
net.parameters(),
lr=wandb.config.LEARNING_RATE,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss()
# save model.onnx
dummy_input = torch.randn(
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device)
torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("checkpoints/model.onnx")
wandb.run.log_artifact(artifact)
# log gradients and weights four time per epoch
wandb.watch(net, log_freq=100)
# print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
# wandb init log
wandb.log(
{
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
},
commit=False,
# Create the trainer
trainer = pl.Trainer(
max_epochs=CONFIG["EPOCHS"],
accelerator=CONFIG["DEVICE"],
# precision=16,
# auto_scale_batch_size="binsearch",
# auto_lr_find=True,
benchmark=CONFIG["BENCHMARK"],
val_check_interval=100,
callbacks=RichProgressBar(),
logger=logger,
log_every_n_steps=1,
)
try:
global_step = 0
for epoch in range(1, wandb.config.EPOCHS + 1):
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
# Training round
for step, (images, true_masks) in enumerate(train_loader):
assert images.shape[1] == net.n_channels, (
f"Network has been defined with {net.n_channels} input channels, "
f"but loaded images have {images.shape[1]} channels. Please check that "
"the images are loaded correctly."
)
# transfer images to device
images = images.to(device=device)
true_masks = true_masks.unsqueeze(1).to(device=device)
# forward
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks)
# backward
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(train_loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
# compute metrics
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
accuracy = (true_masks == pred_masks_bin).float().mean()
dice = dice_coeff(pred_masks_bin, true_masks)
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
# update tqdm progress bar
pbar.update(images.shape[0])
pbar.set_postfix(**{"loss": train_loss.item()})
# log metrics
wandb.log(
{
"train/accuracy": accuracy,
"train/bce": train_loss,
"train/dice": dice,
"train/mae": mae,
}
)
global_step += 1
if global_step % 100 == 0:
# Evaluation round
net.eval()
accuracy = 0
val_loss = 0
dice = 0
mae = 0
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
for images, masks_true in val_loader:
# transfer images to device
images = images.to(device=device)
masks_true = masks_true.unsqueeze(1).to(device=device)
# forward
with torch.inference_mode():
masks_pred = net(images)
# compute metrics
val_loss += criterion(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bars
pbar2.update(images.shape[0])
pbar.refresh()
accuracy /= len(val_loader)
val_loss /= len(val_loader)
dice /= len(val_loader)
mae /= len(val_loader)
# save the last validation batch to table
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
table.add_data(
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
)
# log validation metrics
wandb.log(
{
"val/predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy,
"val/bce": val_loss,
"val/dice": dice,
"val/mae": mae,
},
commit=False,
)
# update hyperparameters
net.train()
scheduler.step(dice)
# export model to onnx format when validation ends
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
wandb.run.log_artifact(artifact)
# stop wandb
wandb.run.finish()
trainer.tune(net)
trainer.fit(model=net)
except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth")
raise
# christmas tree
# stop wandb
wandb.run.finish()

View file

@@ -1,4 +1,4 @@
""" Parts of the U-Net model """
"""Parts of the U-Net model."""
import torch
import torch.nn as nn
@@ -59,8 +59,14 @@ class Up(nn.Module):
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
x1 = F.pad(
input=x1,
pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
)
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
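The padding arithmetic in Up keeps skip connections concatenable when the upsampled decoder map is smaller than the encoder map; an odd difference puts the extra pixel on the bottom/right. A minimal sketch with hypothetical shapes:

import torch
import torch.nn.functional as F

x1 = torch.zeros(1, 8, 50, 51)  # upsampled decoder features
x2 = torch.zeros(1, 8, 53, 53)  # encoder skip connection

diffY = x2.size()[2] - x1.size()[2]  # 3 -> pad 1 on top, 2 on bottom
diffX = x2.size()[3] - x1.size()[3]  # 2 -> pad 1 left, 1 right
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

x = torch.cat([x2, x1], dim=1)  # torch.Size([1, 16, 53, 53])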

View file

@@ -1,14 +1,38 @@
""" Full assembly of the parts to form the complete network """
import itertools
import albumentations as A
import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import wandb
from src.utils.dataset import SphereDataset
from utils.dice import dice_loss
from utils.paste import RandomPaste
from .blocks import *
class_labels = {
1: "sphere",
}
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
class UNet(pl.LightningModule):
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
super(UNet, self).__init__()
# Hyperparameters
self.n_channels = n_channels
self.n_classes = n_classes
self.learning_rate = learning_rate
self.batch_size = batch_size
# log hyperparameters
self.save_hyperparameters()
# Network
self.inc = DoubleConv(n_channels, features[0])
self.downs = nn.ModuleList()
@@ -28,6 +52,8 @@ class UNet(nn.Module):
def forward(self, x):
skips = []
x = x.to(self.device)
x = self.inc(x)
for down in self.downs:
@@ -40,3 +66,224 @@ class UNet(nn.Module):
x = self.outc(x)
return x
def train_dataloader(self):
tf_train = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(),
A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
return DataLoader(
ds_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def val_dataloader(self):
ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
return DataLoader(
ds_valid,
shuffle=False,
batch_size=1,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def training_step(self, batch, batch_idx):
# unpacking
images, masks_true = batch
masks_true = masks_true.unsqueeze(1)
# forward pass
masks_pred = self(images)
# compute metrics
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
dice = dice_loss(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
self.log_dict(
{
"train/accuracy": accuracy,
"train/dice": dice,
"train/dice_bin": dice_bin,
"train/bce": bce,
"train/mae": mae,
},
)
if batch_idx == 22000:
rows = []
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
dice,
dice_bin,
]
)
# logging
try:  # required by auto-finding (auto_lr_find): the tuner replaces the logger with a dummy
self.logger.log_table(
key="train/predictions",
columns=columns,
data=rows,
)
except:
pass
return dict(
accuracy=accuracy,
loss=dice,
bce=bce,
mae=mae,
)
def validation_step(self, batch, batch_idx):
# unpacking
images, masks_true = batch
masks_true = masks_true.unsqueeze(1)
# forward pass
masks_pred = self(images)
# compute metrics
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
dice = dice_loss(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
rows = []
if batch_idx % 50 == 0 or dice < 0.1:
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
dice,
dice_bin,
]
)
return dict(
accuracy=accuracy,
loss=dice,
dice_bin=dice_bin,
bce=bce,
mae=mae,
table_rows=rows,
)
def validation_epoch_end(self, validation_outputs):
# metrics unpacking
accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
# table unpacking
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
rowss = [d["table_rows"] for d in validation_outputs]
rows = list(itertools.chain.from_iterable(rowss))
# logging
try:  # required by auto-finding (auto_lr_find): the tuner replaces the logger with a dummy
self.logger.log_table(
key="val/predictions",
columns=columns,
data=rows,
)
except:
pass
self.log_dict(
{
"val/accuracy": accuracy,
"val/dice": loss,
"val/dice_bin": dice_bin,
"val/bce": bce,
"val/mae": mae,
}
)
# export model to pth
torch.save(self.state_dict(), f"checkpoints/model.pth")
artifact = wandb.Artifact("pth", type="model")
artifact.add_file("checkpoints/model.pth")
wandb.run.log_artifact(artifact)
# export model to onnx
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("checkpoints/model.onnx")
wandb.run.log_artifact(artifact)
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),
lr=self.learning_rate,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
return optimizer

View file

@@ -1,80 +1,18 @@
import torch
from torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all batches, or for a single mask.
def dice_score(inputs, targets, smooth=1, logits=True):
# comment out if your model contains a sigmoid or equivalent activation layer
if logits:
inputs = torch.sigmoid(inputs)
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
# flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
Raises:
ValueError: _description_
Returns:
float: _description_
"""
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
intersection = (inputs * targets).sum()
return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all classes.
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
Returns:
float: _description_
"""
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
"""Dice loss (objective to minimize) between 0 and 1.
Args:
input (Tensor): _description_
target (Tensor): _description_
multiclass (bool, optional): _description_. Defaults to False.
Returns:
float: _description_
"""
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)
def dice_loss(inputs, targets, smooth=1, logits=True):
return 1 - dice_score(inputs, targets, smooth, logits)
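A minimal usage sketch of the new helpers (dummy tensors; with logits=True the sigmoid is applied internally, so raw network outputs can be passed straight in):

import torch

from utils.dice import dice_loss, dice_score

masks_pred = torch.randn(4, 1, 64, 64)  # raw logits
masks_true = torch.randint(0, 2, (4, 1, 64, 64)).float()

loss = dice_loss(masks_pred, masks_true)  # differentiable soft-dice loss
score = dice_score((masks_pred > 0.0).float(), masks_true, logits=False)
print(loss.item(), score.item())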

View file

@ -1,9 +1,10 @@
import os
import random as rd
from pathlib import Path
import albumentations as A
import numpy as np
from PIL import Image, ImageEnhance
import torchvision.transforms as T
from PIL import Image
class RandomPaste(A.DualTransform):
@@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform):
def __init__(
self,
nb,
path_paste_img_dir,
path_paste_mask_dir,
scale_range=(0.1, 0.2),
image_dir,
scale_range=(0.05, 0.25),
always_apply=True,
p=1.0,
):
super().__init__(always_apply, p)
self.path_paste_img_dir = path_paste_img_dir
self.path_paste_mask_dir = path_paste_mask_dir
self.images = []
self.images.extend(list(Path(image_dir).glob("**/*.jpg")))
self.images.extend(list(Path(image_dir).glob("**/*.png")))
self.scale_range = scale_range
self.nb = nb
@@ -38,105 +39,129 @@ class RandomPaste(A.DualTransform):
def targets_as_params(self):
return ["image"]
def apply(self, img, positions, paste_img, paste_mask, **params):
def apply(self, img, augmentations, paste_img, paste_mask, **params):
# convert img to Image, needed for `paste` function
img = Image.fromarray(img)
# copy paste_img and paste_mask
paste_mask = paste_mask.copy()
paste_img = paste_img.copy()
# paste spheres
for pos in positions:
img.paste(paste_img, pos, paste_mask)
for (x, y, shearx, sheary, shape, angle, brightness, contrast) in augmentations:
paste_img = T.functional.adjust_contrast(
paste_img,
contrast_factor=contrast,
)
paste_img = T.functional.adjust_brightness(
paste_img,
brightness_factor=brightness,
)
paste_img = T.functional.affine(
paste_img,
scale=0.95,
angle=angle,
translate=(0, 0),
shear=(shearx, sheary),
interpolation=T.InterpolationMode.BICUBIC,
)
paste_img = T.functional.resize(
paste_img,
size=shape,
interpolation=T.InterpolationMode.BICUBIC,
)
paste_mask = T.functional.affine(
paste_mask,
scale=0.95,
angle=angle,
translate=(0, 0),
shear=(shearx, sheary),
interpolation=T.InterpolationMode.BICUBIC,
)
paste_mask = T.functional.resize(
paste_mask,
size=shape,
interpolation=T.InterpolationMode.BICUBIC,
)
img.paste(paste_img, (x, y), paste_mask)
return np.asarray(img.convert("RGB"))
def apply_to_mask(self, mask, positions, paste_mask, **params):
def apply_to_mask(self, mask, augmentations, paste_mask, **params):
# convert mask to Image, needed for `paste` function
mask = Image.fromarray(mask)
# binarize the mask -> {0, 1}
paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
# copy paste_img and paste_mask
paste_mask = paste_mask.copy()
# paste spheres
for pos in positions:
mask.paste(paste_mask, pos, paste_mask_bin)
for (x, y, shearx, sheary, shape, angle, _, _) in augmentations:
paste_mask = T.functional.affine(
paste_mask,
scale=0.95,
angle=angle,
translate=(0, 0),
shear=(shearx, sheary),
interpolation=T.InterpolationMode.BICUBIC,
)
paste_mask = T.functional.resize(
paste_mask,
size=shape,
interpolation=T.InterpolationMode.BICUBIC,
)
# binarize the mask -> {0, 1}
paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
mask.paste(paste_mask, (x, y), paste_mask_bin)
return np.asarray(mask.convert("L"))
@staticmethod
def overlap(positions, x1, y1, w, h):
for x2, y2 in positions:
if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
return True
return False
def get_params_dependent_on_targets(self, params):
# choose a random image inside the image folder
filename = rd.choice(os.listdir(self.path_paste_img_dir))
# choose a random image and its corresponding mask
img_path = rd.choice(self.images)
mask_path = img_path.parent.joinpath("MASK.PNG")
# load the "paste" image
paste_img = Image.open(
os.path.join(
self.path_paste_img_dir,
filename,
)
).convert("RGBA")
# load its respective mask
paste_mask = Image.open(
os.path.join(
self.path_paste_mask_dir,
filename,
)
).convert("LA")
# load the target image
# load images (w/ transparency)
paste_img = Image.open(img_path).convert("RGBA")
paste_mask = Image.open(mask_path).convert("LA")
target_img = params["image"]
# compute shapes
target_shape = np.array(target_img.shape[:2], dtype=np.uint)
paste_shape = np.array(paste_img.size, dtype=np.uint)
# compute the minimum scaling to fit inside target image
# compute minimum scaling to fit inside target
min_scale = np.min(target_shape / paste_shape)
# randomize the relative scaling
scale = rd.uniform(*self.scale_range)
# rotate the image and its mask
angle = rd.uniform(0, 360)
paste_img = paste_img.rotate(angle, expand=True)
paste_mask = paste_mask.rotate(angle, expand=True)
# scale the "paste" image and its mask
paste_img = paste_img.resize(
tuple((paste_shape * min_scale * scale).astype(np.uint)),
resample=Image.Resampling.LANCZOS,
)
paste_mask = paste_mask.resize(
tuple((paste_shape * min_scale * scale).astype(np.uint)),
resample=Image.Resampling.LANCZOS,
)
# update paste_shape after scaling
paste_shape = np.array(paste_img.size, dtype=np.uint)
# change brightness randomly
filter = ImageEnhance.Brightness(paste_img)
paste_img = filter.enhance(rd.uniform(0.5, 1.5))
# generate some positions
positions = []
# generate augmentations
augmentations = []
NB = rd.randint(1, self.nb)
while len(positions) < NB:
x = rd.randint(0, target_shape[0] - paste_shape[0])
y = rd.randint(0, target_shape[1] - paste_shape[1])
while len(augmentations) < NB:  # TODO: add a max-iteration stop condition
scale = rd.uniform(*self.scale_range) * min_scale
shape = np.array(paste_shape * scale, dtype=np.uint)
x = rd.randint(0, target_shape[0] - shape[0])
y = rd.randint(0, target_shape[1] - shape[1])
# check for overlapping
if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
if RandomPaste.overlap(augmentations, x, y, shape[0], shape[1]):
continue
positions.append((x, y))
shearx = rd.uniform(-2, 2)
sheary = rd.uniform(-2, 2)
angle = rd.uniform(0, 360)
brightness = rd.uniform(0.8, 1.2)
contrast = rd.uniform(0.8, 1.2)
augmentations.append((x, y, shearx, sheary, tuple(shape), angle, brightness, contrast))
params.update(
{
"positions": positions,
"augmentations": augmentations,
"paste_img": paste_img,
"paste_mask": paste_mask,
}
@@ -144,5 +169,9 @@ class RandomPaste(A.DualTransform):
return params
def get_transform_init_args_names(self):
return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"
@staticmethod
def overlap(positions, x1, y1, w, h):
for x2, y2, _, _, _, _, _, _ in positions:
if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
return True
return False
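A minimal sketch of the updated RandomPaste call (hypothetical paths), mirroring its use in train_dataloader: the transform now takes a single sphere directory, globs **/*.jpg and **/*.png inside it, and reads each sphere's mask from the sibling MASK.PNG, instead of requiring separate image and mask directories.

import albumentations as A
from albumentations.pytorch import ToTensorV2

from utils.paste import RandomPaste

transform = A.Compose(
    [
        A.Resize(512, 512),
        RandomPaste(5, "/path/to/realspheres/"),  # up to 5 pasted spheres per image
        A.ToFloat(max_value=255),
        ToTensorV2(),
    ]
)
# augmented = transform(image=image, mask=mask)  # numpy HWC image, HW mask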