refactor: code splitting

Former-commit-id: 7b293e392cc7d4135ef8562faece6f491c623718 [formerly 381a418ceab2cb7f367f07b9f3ea4f3c6a41ecac]
Former-commit-id: c2ddd57c4a3592c93640170a737506bb64b60864
Laurent Fainsin 2022-07-08 16:06:58 +02:00
parent ebb213c565
commit 5f46efa5a1
9 changed files with 258 additions and 259 deletions


@@ -1 +1 @@
0f3136c724eea42fdf1ee15e721ef33604e9a46d
ac8ff07f541ae6d7cba729b20e0d04654c6018c9

src/data/__init__.py Normal file (+1 line)

@@ -0,0 +1 @@
from .dataloader import SyntheticSphere

src/data/dataloader.py Normal file (+50 lines)

@@ -0,0 +1,50 @@
import albumentations as A
import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import wandb
from utils import RandomPaste
from .dataset import SphereDataset
class SyntheticSphere(pl.LightningDataModule):
def __init__(self):
super().__init__()
def train_dataloader(self):
tf_train = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(),
A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
return DataLoader(
ds_train,
shuffle=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def val_dataloader(self):
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)
return DataLoader(
ds_valid,
shuffle=False,
batch_size=1,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
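The commented-out Subset line in train_dataloader sketches a cap on the synthetic training set. Spelled out, it would sit right after ds_train is built, roughly as below; the ~10 000-sample target comes from the original comment, and the zero-step guard is an addition.

import torch

# optional cap: keep roughly 10 000 evenly spaced samples of the synthetic set (sketch only)
step = max(1, len(ds_train) // 10_000)
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), step)))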


@@ -1,7 +1,6 @@
import logging
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
@@ -57,13 +56,17 @@ if __name__ == "__main__":
# log gradients and weights regularly
logger.watch(net, log="all")
# create checkpoint callback
checkpoint_callback = pl.callbacks.ModelCheckpoint(
dirpath="checkpoints",
monitor="val/dice",
)
# Create the trainer
trainer = pl.Trainer(
max_epochs=CONFIG["EPOCHS"],
accelerator=CONFIG["DEVICE"],
# precision=16,
# auto_scale_batch_size="binsearch",
# auto_lr_find=True,
benchmark=CONFIG["BENCHMARK"],
val_check_interval=100,
callbacks=RichProgressBar(),
@@ -71,12 +74,7 @@ if __name__ == "__main__":
log_every_n_steps=1,
)
try:
trainer.tune(net)
trainer.fit(model=net)
except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth")
raise
trainer.fit(model=net)
# stop wandb
wandb.run.finish()
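Note that the new checkpoint_callback is constructed above but only RichProgressBar() reaches the trainer in this hunk. If checkpointing is meant to be active, the Trainer call would take both callbacks, roughly as in this sketch (an adjustment shown for illustration, not part of the commit):

trainer = pl.Trainer(
    max_epochs=CONFIG["EPOCHS"],
    accelerator=CONFIG["DEVICE"],
    benchmark=CONFIG["BENCHMARK"],
    val_check_interval=100,
    callbacks=[RichProgressBar(), checkpoint_callback],  # pass both callbacks as a list
    log_every_n_steps=1,
)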


@@ -1 +1 @@
from .model import UNet
from .module import UNetModule
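With this change the package re-exports the Lightning wrapper instead of the bare network, so callers would import it as follows (assuming src/ stays on the import path):

from unet import UNetModule  # Lightning wrapper, the new package export
from unet.model import UNet  # the plain nn.Module remains importable directly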


@@ -1,37 +1,14 @@
""" Full assembly of the parts to form the complete network """
"""Full assembly of the parts to form the complete network."""
import itertools
import albumentations as A
import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import wandb
from src.utils.dataset import SphereDataset
from utils.dice import dice_loss
from utils.paste import RandomPaste
import torch.nn as nn
from .blocks import *
class_labels = {
1: "sphere",
}
class UNet(pl.LightningModule):
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
super(UNet, self).__init__()
# Hyperparameters
self.n_channels = n_channels
self.n_classes = n_classes
self.learning_rate = learning_rate
self.batch_size = batch_size
# log hyperparameters
self.save_hyperparameters()
# Network
self.inc = DoubleConv(n_channels, features[0])
@@ -66,224 +43,3 @@ class UNet(pl.LightningModule):
x = self.outc(x)
return x
def train_dataloader(self):
tf_train = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(),
A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
return DataLoader(
ds_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def val_dataloader(self):
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)
return DataLoader(
ds_valid,
shuffle=False,
batch_size=1,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def training_step(self, batch, batch_idx):
# unpacking
images, masks_true = batch
masks_true = masks_true.unsqueeze(1)
# forward pass
masks_pred = self(images)
# compute metrics
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
dice = dice_loss(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
self.log_dict(
{
"train/accuracy": accuracy,
"train/dice": dice,
"train/dice_bin": dice_bin,
"train/bce": bce,
"train/mae": mae,
},
)
if batch_idx == 22000:
rows = []
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
dice,
dice_bin,
]
)
# logging
try: # required by autofinding, logger replaced by dummy
self.logger.log_table(
key="train/predictions",
columns=columns,
data=rows,
)
except:
pass
return dict(
accuracy=accuracy,
loss=dice,
bce=bce,
mae=mae,
)
def validation_step(self, batch, batch_idx):
# unpacking
images, masks_true = batch
masks_true = masks_true.unsqueeze(1)
# forward pass
masks_pred = self(images)
# compute metrics
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
dice = dice_loss(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
rows = []
if batch_idx % 50 == 0 or dice > 0.9:
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
dice,
dice_bin,
]
)
return dict(
accuracy=accuracy,
loss=dice,
dice_bin=dice_bin,
bce=bce,
mae=mae,
table_rows=rows,
)
def validation_epoch_end(self, validation_outputs):
# matrics unpacking
accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
# table unpacking
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
rowss = [d["table_rows"] for d in validation_outputs]
rows = list(itertools.chain.from_iterable(rowss))
# logging
try: # required by autofinding, logger replaced by dummy
self.logger.log_table(
key="val/predictions",
columns=columns,
data=rows,
)
except:
pass
self.log_dict(
{
"val/accuracy": accuracy,
"val/dice": loss,
"val/dice_bin": dice_bin,
"val/bce": bce,
"val/mae": mae,
}
)
# export model to pth
torch.save(self.state_dict(), f"checkpoints/model.pth")
artifact = wandb.Artifact("pth", type="model")
artifact.add_file("checkpoints/model.pth")
wandb.run.log_artifact(artifact)
# export model to onnx
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("checkpoints/model.onnx")
wandb.run.log_artifact(artifact)
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),
lr=self.learning_rate,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
return optimizer
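dice_loss, imported from utils.dice both here and in the new module below, is not part of this diff. Its call sites take logits by default, accept logits=False for already-thresholded masks, and treat lower values as better, which suggests a shape like the following sketch; this is an assumption, not the repository's actual implementation:

import torch

def dice_loss(pred: torch.Tensor, target: torch.Tensor, logits: bool = True, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice loss over (B, 1, H, W) masks; logits=True applies a sigmoid to pred first."""
    if logits:
        pred = torch.sigmoid(pred)
    intersection = (pred * target).sum(dim=(-2, -1))
    union = pred.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return 1 - ((2 * intersection + eps) / (union + eps)).mean()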

src/unet/module.py Normal file (+193 lines)

@@ -0,0 +1,193 @@
"""Pytorch lightning wrapper for model."""
import itertools
import pytorch_lightning as pl
import wandb
from unet.model import UNet
from utils.dice import dice_loss
from .blocks import *
class_labels = {
1: "sphere",
}
class UNetModule(pl.LightningModule):
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
super(UNetModule, self).__init__()
# Hyperparameters
self.n_channels = n_channels
self.n_classes = n_classes
self.learning_rate = learning_rate
self.batch_size = batch_size
# log hyperparameters
self.save_hyperparameters()
# Network
self.model = UNet(n_channels, n_classes, features)
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
data, ground_truth = batch # unpacking
ground_truth = ground_truth.unsqueeze(1) # (B, H, W) -> (B, 1, H, W), add channel dim
# forward pass, compute masks
prediction = self.model(data)
binary = (torch.sigmoid(prediction) > 0.5).float() # TODO: check if float necessary
# compute metrics (as a dictionary)
metrics = {
"dice": dice_loss(prediction, ground_truth),
"dice_bin": dice_loss(binary, ground_truth, logits=False),
"bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
"mae": torch.nn.functional.l1_loss(binary, ground_truth),
"accuracy": (ground_truth == binary).float().mean(),
}
# wrap tensors in a dictionary
tensors = {
"data": data,
"ground_truth": ground_truth,
"prediction": prediction,
"binary": binary,
}
return metrics, tensors
def training_step(self, batch, batch_idx):
# compute metrics
metrics, tensors = self.shared_step(batch)
# log metrics
self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
if batch_idx == 5000:
rows = []
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
for i, (img, mask, pred, pred_bin) in enumerate(
zip( # TODO: use a dict comprehension to zip the dictionary
tensors["data"].cpu(),
tensors["ground_truth"].cpu(),
tensors["prediction"].cpu(),
tensors["binary"]
.cpu()
.squeeze(1)
.int()
.numpy(), # TODO: check if .functions can be moved elsewhere
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
metrics["dice"],
metrics["dice_bin"],
]
)
# log table
wandb.log(
{
"train/predictions": wandb.Table(
columns=columns,
data=rows,
)
}
)
return metrics["dice"]
def validation_step(self, batch, batch_idx):
metrics, tensors = self.shared_step(batch)
rows = []
if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
for i, (img, mask, pred, pred_bin) in enumerate(
zip( # TODO: use a dict comprehension to zip the dictionary
tensors["data"].cpu(),
tensors["ground_truth"].cpu(),
tensors["prediction"].cpu(),
tensors["binary"]
.cpu()
.squeeze(1)
.int()
.numpy(), # TODO: check if .functions can be moved elsewhere
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
metrics["dice"],
metrics["dice_bin"],
]
)
return metrics, rows # return both so validation_epoch_end can rebuild the table
def validation_epoch_end(self, validation_outputs):
# unpacking
metricss, rowss = zip(*validation_outputs) # validation_step returns (metrics, rows) pairs
# metrics flattening
metrics = {
"dice": torch.stack([d["dice"] for d in metricss]).mean(),
"dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
"bce": torch.stack([d["bce"] for d in metricss]).mean(),
"mae": torch.stack([d["mae"] for d in metricss]).mean(),
"accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
}
# log metrics
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
# rows flattening
rows = list(itertools.chain.from_iterable(rowss))
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
# log table
wandb.log(
{
"val/predictions": wandb.Table(
columns=columns,
data=rows,
)
}
)
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),
lr=self.learning_rate,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
return optimizer
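Putting the split together, a minimal training script for the new layout would look roughly like this. The directory paths, project name and hyperparameter values are placeholders, src/ is assumed to be on the import path, and wandb.init() must run before the dataloaders and optimizer read wandb.config:

import pytorch_lightning as pl
import wandb

from data import SyntheticSphere
from unet import UNetModule

# placeholder configuration; the real values live in the project's wandb setup
CONFIG = {
    "IMG_SIZE": 512, "SPHERES": 5, "DIR_SPHERE": "spheres/",
    "DIR_TRAIN_IMG": "train/", "DIR_VALID_IMG": "valid/",
    "BATCH_SIZE": 16, "WORKERS": 8, "PIN_MEMORY": True,
    "WEIGHT_DECAY": 1e-8, "MOMENTUM": 0.9,
}
wandb.init(project="sphere-segmentation", config=CONFIG)  # hypothetical project name

model = UNetModule(n_channels=3, n_classes=1, learning_rate=1e-4, batch_size=CONFIG["BATCH_SIZE"])
datamodule = SyntheticSphere()

trainer = pl.Trainer(max_epochs=10, accelerator="gpu", log_every_n_steps=1)
trainer.fit(model, datamodule=datamodule)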


@@ -0,0 +1 @@
from .paste import RandomPaste
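RandomPaste itself is outside this diff. Given how it is used in the Compose pipeline, with a sphere count and a sphere directory, and given that both the image and the segmentation mask must gain the pasted spheres, it is presumably an albumentations transform acting on both targets. A hypothetical skeleton of such a transform, not the repository's implementation:

import albumentations as A

class RandomPaste(A.DualTransform):
    """Hypothetical skeleton: paste up to nb sphere crops from sphere_dir onto the image
    and draw the matching blobs into the segmentation mask."""

    def __init__(self, nb, sphere_dir, always_apply=False, p=1.0):
        super().__init__(always_apply, p)
        self.nb = nb
        self.sphere_dir = sphere_dir

    def apply(self, img, **params):
        # paste the sphere pixels into the image here
        return img

    def apply_to_mask(self, mask, **params):
        # add the pasted spheres to the mask here
        return mask

In practice the paste positions would be drawn once per call (via get_params or get_params_dependent_on_targets) so that apply and apply_to_mask stay consistent.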