feat: add more config fields

Former-commit-id: 4b288d6c47aec460704348a003b48620327c218d [formerly 1af19e81906cd94483f7e9aadc7fdd72d35dc5e1]
Former-commit-id: 605c529e3db6dde13cbbcf41f9517c3380b69b29
This commit is contained in:
Laurent Fainsin 2022-08-26 11:10:49 +02:00
parent 382b3ff91a
commit 8d6115afb2
7 changed files with 32 additions and 270 deletions

View file

@ -19,8 +19,13 @@ class Spheres(pl.LightningDataModule):
def train_dataloader(self):
transforms = A.Compose(
[
A.ToFloat(max_value=255),
ToTensorV2(),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
max_pixel_value=255,
), # [0, 255] -> [0.0, 1.0] normalized
# A.ToFloat(max_value=255),
ToTensorV2(), # HWC -> CHW
],
bbox_params=A.BboxParams(
format="pascal_voc",
@ -30,17 +35,18 @@ class Spheres(pl.LightningDataModule):
),
)
dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
print(f"len(dataset)={len(dataset)}")
dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
# dataset = Subset(dataset, list(range(20))) # somehow this allows to better utilize the gpu
return DataLoader(
dataset,
shuffle=True,
shuffle=False,
persistent_workers=True,
prefetch_factor=wandb.config.PREFETCH_FACTOR,
batch_size=wandb.config.TRAIN_BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
num_workers=wandb.config.WORKERS,
collate_fn=collate_fn,
)

View file

@ -7,7 +7,6 @@ from pytorch_lightning.loggers import WandbLogger
import wandb
from data import Spheres
from mrcnn import MRCNNModule
from unet import UNetModule
from utils import ArtifactLog, TableLog
if __name__ == "__main__":
@ -16,7 +15,7 @@ if __name__ == "__main__":
# setup wandb
logger = WandbLogger(
project="U-Net",
project="Mask R-CNN",
config="wandb.yaml",
settings=wandb.Settings(
code_dir="./src/",
@ -24,15 +23,9 @@ if __name__ == "__main__":
)
# seed random generators
pl.seed_everything(69420, workers=True)
# Create network
# model = UNetModule(
# n_channels=wandb.config.N_CHANNELS,
# n_classes=wandb.config.N_CLASSES,
# features=wandb.config.FEATURES,
# )
pl.seed_everything(wandb.config.SEED, workers=True)
# Create Network
model = MRCNNModule(
hidden_layer_size=-1,
n_classes=2,
@ -54,13 +47,13 @@ if __name__ == "__main__":
max_epochs=wandb.config.EPOCHS,
accelerator=wandb.config.DEVICE,
benchmark=wandb.config.BENCHMARK,
deterministic=True,
precision=16,
deterministic=wandb.config.DETERMINISTIC,
precision=wandb.config.PRECISION,
logger=logger,
log_every_n_steps=1,
log_every_n_steps=5,
# val_check_interval=100,
callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
# profiler="simple",
# profiler="advanced",
num_sanity_val_steps=0,
)

View file

@ -1 +0,0 @@
from .module import UNetModule

View file

@ -1,82 +0,0 @@
"""Parts of the U-Net model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
# first convolution
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
# second convolution
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels),
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
x1 = F.pad(
input=x1,
pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)

View file

@ -1,43 +0,0 @@
"""Full assembly of the parts to form the complete network."""
import torch.nn as nn
from .blocks import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
super(UNet, self).__init__()
# Network
self.inc = DoubleConv(n_channels, features[0])
self.downs = nn.ModuleList()
for i in range(len(features) - 1):
self.downs.append(
Down(*features[i : i + 2]),
)
self.ups = nn.ModuleList()
for i in range(len(features) - 1):
self.ups.append(
Up(*features[-1 - i : -3 - i : -1]),
)
self.outc = OutConv(features[0], n_classes)
def forward(self, x):
skips = []
x = self.inc(x)
for down in self.downs:
skips.append(x)
x = down(x)
for up, skip in zip(self.ups, reversed(skips)):
x = up(x, skip)
x = self.outc(x)
return x

View file

@ -1,117 +0,0 @@
"""Pytorch lightning wrapper for model."""
import itertools
import pytorch_lightning as pl
import wandb
from unet.model import UNet
from utils.dice import dice_loss
from .blocks import *
class_labels = {
1: "sphere",
}
class UNetModule(pl.LightningModule):
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
super(UNetModule, self).__init__()
# Hyperparameters
self.n_channels = n_channels
self.n_classes = n_classes
# log hyperparameters
self.save_hyperparameters()
# Network
self.model = UNet(n_channels, n_classes, features)
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
data, ground_truth = batch # unpacking
ground_truth = ground_truth.unsqueeze(1) # 1HW -> HW
# forward pass, compute masks
prediction = self.model(data)
binary = (torch.sigmoid(prediction) > 0.5).half()
# compute metrics (in dictionnary)
metrics = {
"dice": dice_loss(prediction, ground_truth),
"dice_bin": dice_loss(binary, ground_truth, logits=False),
"bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
"mae": torch.nn.functional.l1_loss(binary, ground_truth),
"accuracy": (ground_truth == binary).float().mean(),
}
# wrap tensors in dictionnary
predictions = {
"linear": prediction,
"binary": binary,
}
return metrics, predictions
def training_step(self, batch, batch_idx):
# compute metrics
metrics, _ = self.shared_step(batch)
# log metrics
self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
return metrics["dice"]
def validation_step(self, batch, batch_idx):
# compute metrics
metrics, predictions = self.shared_step(batch)
# log metrics
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
return metrics, predictions
def validation_epoch_end(self, validation_outputs):
# unpacking
metricss = [v[0] for v in validation_outputs]
rowss = [v[1] for v in validation_outputs]
# metrics flattening
metrics = {
"dice": torch.stack([d["dice"] for d in metricss]).mean(),
"dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
"bce": torch.stack([d["bce"] for d in metricss]).mean(),
"mae": torch.stack([d["mae"] for d in metricss]).mean(),
"accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
}
# log metrics
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
# rows flattening
rows = list(itertools.chain.from_iterable(rowss))
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
# log table
wandb.log(
{
"val/predictions": wandb.Table(
columns=columns,
data=rows,
)
}
)
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),
lr=wandb.config.LEARNING_RATE,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
return optimizer

View file

@ -17,11 +17,17 @@ AMP:
PIN_MEMORY:
value: True
BENCHMARK:
value: True
DETERMINISTIC:
value: False
PRECISION:
value: 16
SEED:
value: 69420
DEVICE:
value: gpu
WORKERS:
value: 1
value: 16
IMG_SIZE:
value: 512
@ -29,17 +35,17 @@ SPHERES:
value: 3
EPOCHS:
value: 3
value: 10
TRAIN_BATCH_SIZE:
value: 2 # 100
value: 8
VAL_BATCH_SIZE:
value: 0 # 10
value: 0
PREFETCH_FACTOR:
value: 1
value: 2
LEARNING_RATE:
value: 1.0e-4
value: 0.005
WEIGHT_DECAY:
value: 1.0e-8
value: 0.0005
MOMENTUM:
value: 0.9