refactor: code splitting
Former-commit-id: 7b293e392cc7d4135ef8562faece6f491c623718 [formerly 381a418ceab2cb7f367f07b9f3ea4f3c6a41ecac] Former-commit-id: c2ddd57c4a3592c93640170a737506bb64b60864
parent ebb213c565
commit 5f46efa5a1

@@ -1 +1 @@
0f3136c724eea42fdf1ee15e721ef33604e9a46d
ac8ff07f541ae6d7cba729b20e0d04654c6018c9

src/data/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .dataloader import SyntheticSphere

src/data/dataloader.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import albumentations as A
import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

import wandb
from utils import RandomPaste

from .dataset import SphereDataset


class SyntheticSphere(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def train_dataloader(self):
        tf_train = A.Compose(
            [
                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
                A.Flip(),
                A.ColorJitter(),
                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
                A.GaussianBlur(),
                A.ISONoise(),
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )

        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))

        return DataLoader(
            ds_train,
            shuffle=True,
            batch_size=wandb.config.BATCH_SIZE,
            num_workers=wandb.config.WORKERS,
            pin_memory=wandb.config.PIN_MEMORY,
        )

    def val_dataloader(self):
        ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)

        return DataLoader(
            ds_valid,
            shuffle=False,
            batch_size=1,
            num_workers=wandb.config.WORKERS,
            pin_memory=wandb.config.PIN_MEMORY,
        )
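
A usage note on the new data module: SyntheticSphere reads every setting from wandb.config, so a wandb run has to be initialised before its dataloaders are built. The sketch below is illustrative only and not part of this commit; the project name and all config values are placeholder assumptions, and it assumes src/ is on the import path.

# Minimal sketch (assumptions, not from this commit): consuming SyntheticSphere.
import pytorch_lightning as pl
import wandb

from data import SyntheticSphere  # assumes src/ is on PYTHONPATH

if __name__ == "__main__":
    wandb.init(
        project="sphere-segmentation",  # hypothetical project name
        config=dict(
            IMG_SIZE=512,
            SPHERES=5,
            DIR_SPHERE="path/to/spheres",
            DIR_TRAIN_IMG="path/to/train",
            DIR_VALID_IMG="path/to/valid",
            BATCH_SIZE=16,
            WORKERS=8,
            PIN_MEMORY=True,
        ),
    )

    datamodule = SyntheticSphere()
    trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(model, datamodule=datamodule)  # model defined elsewhere, e.g. the UNetModule below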

src/train.py (16 lines changed)
@@ -1,7 +1,6 @@
import logging

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger

|
@ -57,13 +56,17 @@ if __name__ == "__main__":
|
|||
# log gradients and weights regularly
|
||||
logger.watch(net, log="all")
|
||||
|
||||
# create checkpoint callback
|
||||
checkpoint_callback = pl.ModelCheckpoint(
|
||||
dirpath="checkpoints",
|
||||
monitor="val/dice",
|
||||
)
|
||||
|
||||
# Create the trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=CONFIG["EPOCHS"],
|
||||
accelerator=CONFIG["DEVICE"],
|
||||
# precision=16,
|
||||
# auto_scale_batch_size="binsearch",
|
||||
# auto_lr_find=True,
|
||||
benchmark=CONFIG["BENCHMARK"],
|
||||
val_check_interval=100,
|
||||
callbacks=RichProgressBar(),
|
||||
|
@@ -71,12 +74,7 @@ if __name__ == "__main__":
        log_every_n_steps=1,
    )

    try:
        trainer.tune(net)
        trainer.fit(model=net)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), "INTERRUPTED.pth")
        raise
    trainer.fit(model=net)

    # stop wandb
    wandb.run.finish()

@@ -1 +1 @@
from .model import UNet
from .module import UNetModule

@@ -1,37 +1,14 @@
""" Full assembly of the parts to form the complete network """
"""Full assembly of the parts to form the complete network."""

import itertools

import albumentations as A
import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

import wandb
from src.utils.dataset import SphereDataset
from utils.dice import dice_loss
from utils.paste import RandomPaste
import torch.nn as nn

from .blocks import *

class_labels = {
    1: "sphere",
}


class UNet(pl.LightningModule):
    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
        super(UNet, self).__init__()

        # Hyperparameters
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        # log hyperparameters
        self.save_hyperparameters()

        # Network
        self.inc = DoubleConv(n_channels, features[0])

@@ -66,224 +43,3 @@ class UNet(pl.LightningModule):
        x = self.outc(x)

        return x

    def train_dataloader(self):
        tf_train = A.Compose(
            [
                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
                A.Flip(),
                A.ColorJitter(),
                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
                A.GaussianBlur(),
                A.ISONoise(),
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ],
        )

        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))

        return DataLoader(
            ds_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=wandb.config.WORKERS,
            pin_memory=wandb.config.PIN_MEMORY,
        )

    def val_dataloader(self):
        ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)

        return DataLoader(
            ds_valid,
            shuffle=False,
            batch_size=1,
            num_workers=wandb.config.WORKERS,
            pin_memory=wandb.config.PIN_MEMORY,
        )

    def training_step(self, batch, batch_idx):
        # unpacking
        images, masks_true = batch
        masks_true = masks_true.unsqueeze(1)

        # forward pass
        masks_pred = self(images)

        # compute metrics
        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
        dice = dice_loss(masks_pred, masks_true)

        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
        accuracy = (masks_true == masks_pred_bin).float().mean()

        self.log_dict(
            {
                "train/accuracy": accuracy,
                "train/dice": dice,
                "train/dice_bin": dice_bin,
                "train/bce": bce,
                "train/mae": mae,
            },
        )

        if batch_idx == 22000:
            rows = []
            columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
            for i, (img, mask, pred, pred_bin) in enumerate(
                zip(
                    images.cpu(),
                    masks_true.cpu(),
                    masks_pred.cpu(),
                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
                )
            ):
                rows.append(
                    [
                        i,
                        wandb.Image(img),
                        wandb.Image(mask),
                        wandb.Image(
                            pred,
                            masks={
                                "predictions": {
                                    "mask_data": pred_bin,
                                    "class_labels": class_labels,
                                },
                            },
                        ),
                        dice,
                        dice_bin,
                    ]
                )

            # logging
            try:  # required by autofinding, logger replaced by dummy
                self.logger.log_table(
                    key="train/predictions",
                    columns=columns,
                    data=rows,
                )
            except:
                pass

        return dict(
            accuracy=accuracy,
            loss=dice,
            bce=bce,
            mae=mae,
        )

    def validation_step(self, batch, batch_idx):
        # unpacking
        images, masks_true = batch
        masks_true = masks_true.unsqueeze(1)

        # forward pass
        masks_pred = self(images)

        # compute metrics
        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
        dice = dice_loss(masks_pred, masks_true)

        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
        accuracy = (masks_true == masks_pred_bin).float().mean()

        rows = []
        if batch_idx % 50 == 0 or dice > 0.9:
            for i, (img, mask, pred, pred_bin) in enumerate(
                zip(
                    images.cpu(),
                    masks_true.cpu(),
                    masks_pred.cpu(),
                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
                )
            ):
                rows.append(
                    [
                        i,
                        wandb.Image(img),
                        wandb.Image(mask),
                        wandb.Image(
                            pred,
                            masks={
                                "predictions": {
                                    "mask_data": pred_bin,
                                    "class_labels": class_labels,
                                },
                            },
                        ),
                        dice,
                        dice_bin,
                    ]
                )

        return dict(
            accuracy=accuracy,
            loss=dice,
            dice_bin=dice_bin,
            bce=bce,
            mae=mae,
            table_rows=rows,
        )

    def validation_epoch_end(self, validation_outputs):
        # metrics unpacking
        accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
        dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
        loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
        bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
        mae = torch.stack([d["mae"] for d in validation_outputs]).mean()

        # table unpacking
        columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
        rowss = [d["table_rows"] for d in validation_outputs]
        rows = list(itertools.chain.from_iterable(rowss))

        # logging
        try:  # required by autofinding, logger replaced by dummy
            self.logger.log_table(
                key="val/predictions",
                columns=columns,
                data=rows,
            )
        except:
            pass

        self.log_dict(
            {
                "val/accuracy": accuracy,
                "val/dice": loss,
                "val/dice_bin": dice_bin,
                "val/bce": bce,
                "val/mae": mae,
            }
        )

        # export model to pth
        torch.save(self.state_dict(), f"checkpoints/model.pth")
        artifact = wandb.Artifact("pth", type="model")
        artifact.add_file("checkpoints/model.pth")
        wandb.run.log_artifact(artifact)

        # export model to onnx
        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
        torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
        artifact = wandb.Artifact("onnx", type="model")
        artifact.add_file("checkpoints/model.onnx")
        wandb.run.log_artifact(artifact)

    def configure_optimizers(self):
        optimizer = torch.optim.RMSprop(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=wandb.config.WEIGHT_DECAY,
            momentum=wandb.config.MOMENTUM,
        )

        return optimizer
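
With the Lightning-specific code removed, UNet is left as a plain nn.Module that can be built and exercised on its own. Below is a minimal sketch, illustrative only and not part of this commit; the channel/class counts are assumptions, and the 512x512 input shape simply mirrors the ONNX dummy input removed above.

# Minimal sketch (assumptions, not from this commit): standalone use of the slimmed-down UNet.
import torch

from unet.model import UNet  # assumes src/ is on PYTHONPATH

net = UNet(n_channels=3, n_classes=1)  # assumed channel/class counts
x = torch.randn(1, 3, 512, 512)  # same shape as the old ONNX dummy input
with torch.no_grad():
    logits = net(x)
print(logits.shape)  # expected: torch.Size([1, 1, 512, 512]), one logit per pixel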

src/unet/module.py (new file, 193 lines)
@@ -0,0 +1,193 @@
"""Pytorch lightning wrapper for model."""

import itertools

import pytorch_lightning as pl

import wandb
from unet.model import UNet
from utils.dice import dice_loss

from .blocks import *

class_labels = {
    1: "sphere",
}


class UNetModule(pl.LightningModule):
    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
        super(UNetModule, self).__init__()

        # Hyperparameters
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.learning_rate = learning_rate
        self.batch_size = batch_size

        # log hyperparameters
        self.save_hyperparameters()

        # Network
        self.model = UNet(n_channels, n_classes, features)

    def forward(self, x):
        return self.model(x)

    def shared_step(self, batch):
        data, ground_truth = batch  # unpacking
        ground_truth = ground_truth.unsqueeze(1)  # HW -> 1HW

        # forward pass, compute masks
        prediction = self.model(data)
        binary = (torch.sigmoid(prediction) > 0.5).float()  # TODO: check if float necessary

        # compute metrics (in dictionary)
        metrics = {
            "dice": dice_loss(prediction, ground_truth),
            "dice_bin": dice_loss(binary, ground_truth, logits=False),
            "bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
            "mae": torch.nn.functional.l1_loss(binary, ground_truth),
            "accuracy": (ground_truth == binary).float().mean(),
        }

        # wrap tensors in dictionary
        tensors = {
            "data": data,
            "ground_truth": ground_truth,
            "prediction": prediction,
            "binary": binary,
        }

        return metrics, tensors

    def training_step(self, batch, batch_idx):
        # compute metrics
        metrics, tensors = self.shared_step(batch)

        # log metrics
        self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))

        if batch_idx == 5000:
            rows = []
            columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
            for i, (img, mask, pred, pred_bin) in enumerate(
                zip(  # TODO: use a list comprehension to zip the dictionary
                    tensors["data"].cpu(),
                    tensors["ground_truth"].cpu(),
                    tensors["prediction"].cpu(),
                    tensors["binary"]
                    .cpu()
                    .squeeze(1)
                    .int()
                    .numpy(),  # TODO: check if .functions can be moved elsewhere
                )
            ):
                rows.append(
                    [
                        i,
                        wandb.Image(img),
                        wandb.Image(mask),
                        wandb.Image(
                            pred,
                            masks={
                                "predictions": {
                                    "mask_data": pred_bin,
                                    "class_labels": class_labels,
                                },
                            },
                        ),
                        metrics["dice"],
                        metrics["dice_bin"],
                    ]
                )

            # log table
            wandb.log(
                {
                    "train/predictions": wandb.Table(
                        columns=columns,
                        data=rows,
                    )
                }
            )

        return metrics["dice"]

    def validation_step(self, batch, batch_idx):
        metrics, tensors = self.shared_step(batch)

        rows = []
        if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
            for i, (img, mask, pred, pred_bin) in enumerate(
                zip(  # TODO: use a list comprehension to zip the dictionary
                    tensors["data"].cpu(),
                    tensors["ground_truth"].cpu(),
                    tensors["prediction"].cpu(),
                    tensors["binary"]
                    .cpu()
                    .squeeze(1)
                    .int()
                    .numpy(),  # TODO: check if .functions can be moved elsewhere
                )
            ):
                rows.append(
                    [
                        i,
                        wandb.Image(img),
                        wandb.Image(mask),
                        wandb.Image(
                            pred,
                            masks={
                                "predictions": {
                                    "mask_data": pred_bin,
                                    "class_labels": class_labels,
                                },
                            },
                        ),
                        metrics["dice"],
                        metrics["dice_bin"],
                    ]
                )

        return metrics

    def validation_epoch_end(self, validation_outputs):
        # unpacking
        metricss, rowss = validation_outputs

        # metrics flattening
        metrics = {
            "dice": torch.stack([d["dice"] for d in metricss]).mean(),
            "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
            "bce": torch.stack([d["bce"] for d in metricss]).mean(),
            "mae": torch.stack([d["mae"] for d in metricss]).mean(),
            "accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
        }

        # log metrics
        self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))

        # rows flattening
        rows = list(itertools.chain.from_iterable(rowss))
        columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]

        # log table
        wandb.log(
            {
                "val/predictions": wandb.Table(
                    columns=columns,
                    data=rows,
                )
            }
        )

    def configure_optimizers(self):
        optimizer = torch.optim.RMSprop(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=wandb.config.WEIGHT_DECAY,
            momentum=wandb.config.MOMENTUM,
        )

        return optimizer

@@ -0,0 +1 @@
from .paste import RandomPaste
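
Taken together, the split leaves three pieces to wire up in src/train.py: the SyntheticSphere data module, the UNetModule Lightning wrapper, and the Trainer. The sketch below shows one way they could fit together; it is not part of this commit, and the hyperparameter values are assumptions.

# Minimal sketch (assumptions, not from this commit): wiring the refactored pieces together.
import pytorch_lightning as pl
import wandb

from data import SyntheticSphere
from unet import UNetModule

if __name__ == "__main__":
    # WEIGHT_DECAY and MOMENTUM are read by configure_optimizers; values are placeholders,
    # and the data-module keys from the earlier sketch would also need to be in this config.
    wandb.init(config=dict(WEIGHT_DECAY=1e-8, MOMENTUM=0.9))

    net = UNetModule(n_channels=3, n_classes=1, learning_rate=1e-4, batch_size=16)
    datamodule = SyntheticSphere()

    trainer = pl.Trainer(max_epochs=10, log_every_n_steps=1)
    trainer.fit(model=net, datamodule=datamodule)

    wandb.run.finish()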