feat: add more config fields

Former-commit-id: 4b288d6c47aec460704348a003b48620327c218d [formerly 1af19e81906cd94483f7e9aadc7fdd72d35dc5e1]
Former-commit-id: 605c529e3db6dde13cbbcf41f9517c3380b69b29
Laurent Fainsin 2022-08-26 11:10:49 +02:00
parent 382b3ff91a
commit 8d6115afb2
7 changed files with 32 additions and 270 deletions


@@ -19,8 +19,13 @@ class Spheres(pl.LightningDataModule):
     def train_dataloader(self):
         transforms = A.Compose(
             [
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
+                A.Normalize(
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225],
+                    max_pixel_value=255,
+                ),  # [0, 255] -> [0.0, 1.0] normalized
+                # A.ToFloat(max_value=255),
+                ToTensorV2(),  # HWC -> CHW
             ],
             bbox_params=A.BboxParams(
                 format="pascal_voc",
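For context, a minimal standalone sketch (dummy uint8 image, not a real sample from the dataset) of what the new A.Normalize step computes compared with the previous A.ToFloat, which only divided by 255:

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

# dummy 4x4 RGB image with all pixels at 128, standing in for a real sample
img = np.full((4, 4, 3), 128, dtype=np.uint8)

transforms = A.Compose(
    [
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255,
        ),
        ToTensorV2(),  # HWC numpy array -> CHW torch tensor
    ]
)

out = transforms(image=img)["image"]
# A.Normalize computes (pixel / max_pixel_value - mean) / std per channel,
# so 128 in the red channel becomes roughly (0.502 - 0.485) / 0.229 ≈ 0.074
print(out.shape, out[0, 0, 0])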
@@ -30,17 +35,18 @@ class Spheres(pl.LightningDataModule):
                 ),
             ),
         )

-        dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
+        dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
+        print(f"len(dataset)={len(dataset)}")
         dataset = Subset(dataset, list(range(len(dataset))))  # somehow this allows to better utilize the gpu
+        # dataset = Subset(dataset, list(range(20)))  # somehow this allows to better utilize the gpu

         return DataLoader(
             dataset,
-            shuffle=True,
+            shuffle=False,
+            persistent_workers=True,
             prefetch_factor=wandb.config.PREFETCH_FACTOR,
             batch_size=wandb.config.TRAIN_BATCH_SIZE,
-            num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
+            num_workers=wandb.config.WORKERS,
             collate_fn=collate_fn,
         )
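A minimal sketch of the DataLoader knobs introduced here, with a dummy TensorDataset and hard-coded numbers standing in for the wandb.config fields (the detection-specific collate_fn is omitted): persistent_workers keeps worker processes alive between epochs, prefetch_factor controls how many batches each worker loads ahead, and pin_memory uses page-locked host memory for faster GPU transfers.

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    # dummy data standing in for RealDataset (hypothetical shapes and values)
    dataset = TensorDataset(torch.randn(16, 3, 64, 64), torch.zeros(16, dtype=torch.long))

    loader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=8,             # wandb.config.TRAIN_BATCH_SIZE in the real module
        num_workers=2,            # wandb.config.WORKERS
        prefetch_factor=2,        # wandb.config.PREFETCH_FACTOR, only used when num_workers > 0
        persistent_workers=True,  # keep workers alive across epochs, avoids respawn cost
        pin_memory=True,          # wandb.config.PIN_MEMORY
    )

    for images, labels in loader:
        pass  # training step would go here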


@@ -7,7 +7,6 @@ from pytorch_lightning.loggers import WandbLogger
 import wandb
 from data import Spheres
 from mrcnn import MRCNNModule
-from unet import UNetModule
 from utils import ArtifactLog, TableLog

 if __name__ == "__main__":
@@ -16,7 +15,7 @@ if __name__ == "__main__":
     # setup wandb
     logger = WandbLogger(
-        project="U-Net",
+        project="Mask R-CNN",
         config="wandb.yaml",
         settings=wandb.Settings(
             code_dir="./src/",
@@ -24,15 +23,9 @@ if __name__ == "__main__":
     )

     # seed random generators
-    pl.seed_everything(69420, workers=True)
+    pl.seed_everything(wandb.config.SEED, workers=True)

-    # Create network
-    # model = UNetModule(
-    #     n_channels=wandb.config.N_CHANNELS,
-    #     n_classes=wandb.config.N_CLASSES,
-    #     features=wandb.config.FEATURES,
-    # )
-
+    # Create Network
     model = MRCNNModule(
         hidden_layer_size=-1,
         n_classes=2,
@@ -54,13 +47,13 @@ if __name__ == "__main__":
         max_epochs=wandb.config.EPOCHS,
         accelerator=wandb.config.DEVICE,
         benchmark=wandb.config.BENCHMARK,
-        deterministic=True,
-        precision=16,
+        deterministic=wandb.config.DETERMINISTIC,
+        precision=wandb.config.PRECISION,
         logger=logger,
-        log_every_n_steps=1,
+        log_every_n_steps=5,
         # val_check_interval=100,
         callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
-        # profiler="simple",
+        # profiler="advanced",
         num_sanity_val_steps=0,
     )
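A standalone sketch of how the Trainer is now driven by config values; the dict below uses hypothetical numbers standing in for the fields read from wandb.yaml via wandb.config:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

# hypothetical stand-ins for the values normally loaded from wandb.yaml into wandb.config
config = {
    "SEED": 69420,
    "EPOCHS": 10,
    "DEVICE": "gpu",
    "BENCHMARK": True,
    "DETERMINISTIC": False,
    "PRECISION": 16,
}

pl.seed_everything(config["SEED"], workers=True)

trainer = pl.Trainer(
    max_epochs=config["EPOCHS"],
    accelerator=config["DEVICE"],
    benchmark=config["BENCHMARK"],          # cudnn autotuner, helps when input sizes are fixed
    deterministic=config["DETERMINISTIC"],  # usually kept off when benchmark is on
    precision=config["PRECISION"],          # 16 enables mixed precision
    log_every_n_steps=5,
    callbacks=[RichProgressBar()],
)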


@@ -1 +0,0 @@
-from .module import UNetModule


@@ -1,82 +0,0 @@
-"""Parts of the U-Net model."""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class DoubleConv(nn.Module):
-    """(convolution => [BN] => ReLU) * 2"""
-
-    def __init__(self, in_channels, out_channels, mid_channels=None):
-        super().__init__()
-
-        if not mid_channels:
-            mid_channels = out_channels
-
-        self.double_conv = nn.Sequential(
-            # first convolution
-            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm2d(mid_channels),
-            nn.ReLU(inplace=True),
-            # second convolution
-            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True),
-        )
-
-    def forward(self, x):
-        return self.double_conv(x)
-
-
-class Down(nn.Module):
-    """Downscaling with maxpool then double conv"""
-
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-
-        self.maxpool_conv = nn.Sequential(
-            nn.MaxPool2d(2),
-            DoubleConv(in_channels, out_channels),
-        )
-
-    def forward(self, x):
-        return self.maxpool_conv(x)
-
-
-class Up(nn.Module):
-    """Upscaling then double conv"""
-
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-
-        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
-        self.conv = DoubleConv(in_channels, out_channels)
-
-    def forward(self, x1, x2):
-        x1 = self.up(x1)
-
-        # input is CHW
-        diffY = x2.size()[2] - x1.size()[2]
-        diffX = x2.size()[3] - x1.size()[3]
-
-        diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
-        diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
-
-        x1 = F.pad(
-            input=x1,
-            pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
-        )
-
-        x = torch.cat([x2, x1], dim=1)
-
-        return self.conv(x)
-
-
-class OutConv(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(OutConv, self).__init__()
-
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
-
-    def forward(self, x):
-        return self.conv(x)
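Side note on the removed Up block: the F.pad call re-aligns the upsampled decoder tensor with the encoder skip connection when their spatial sizes differ (odd input sizes halved earlier in the encoder). A standalone sketch with made-up shapes, using plain integer division in place of torch.div(..., rounding_mode="trunc"):

import torch
import torch.nn.functional as F

# hypothetical shapes: the encoder skip tensor is one pixel larger than the
# upsampled decoder tensor, which happens when an odd spatial size was halved earlier
x1 = torch.randn(1, 64, 24, 24)  # upsampled decoder features
x2 = torch.randn(1, 64, 25, 25)  # encoder skip connection

diffY = x2.size(2) - x1.size(2)  # 1
diffX = x2.size(3) - x1.size(3)  # 1

# F.pad order is (left, right, top, bottom); any odd difference goes to the right/bottom
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

x = torch.cat([x2, x1], dim=1)  # both are 25x25 now, so channel-wise concat works
print(x.shape)  # torch.Size([1, 128, 25, 25])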


@@ -1,43 +0,0 @@
-"""Full assembly of the parts to form the complete network."""
-
-import torch.nn as nn
-
-from .blocks import *
-
-
-class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
-        super(UNet, self).__init__()
-
-        # Network
-        self.inc = DoubleConv(n_channels, features[0])
-
-        self.downs = nn.ModuleList()
-        for i in range(len(features) - 1):
-            self.downs.append(
-                Down(*features[i : i + 2]),
-            )
-
-        self.ups = nn.ModuleList()
-        for i in range(len(features) - 1):
-            self.ups.append(
-                Up(*features[-1 - i : -3 - i : -1]),
-            )
-
-        self.outc = OutConv(features[0], n_classes)
-
-    def forward(self, x):
-        skips = []
-
-        x = self.inc(x)
-        for down in self.downs:
-            skips.append(x)
-            x = down(x)
-
-        for up, skip in zip(self.ups, reversed(skips)):
-            x = up(x, skip)
-
-        x = self.outc(x)
-
-        return x


@@ -1,117 +0,0 @@
-"""Pytorch lightning wrapper for model."""
-
-import itertools
-
-import pytorch_lightning as pl
-import wandb
-
-from unet.model import UNet
-from utils.dice import dice_loss
-
-from .blocks import *
-
-class_labels = {
-    1: "sphere",
-}
-
-
-class UNetModule(pl.LightningModule):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
-        super(UNetModule, self).__init__()
-
-        # Hyperparameters
-        self.n_channels = n_channels
-        self.n_classes = n_classes
-
-        # log hyperparameters
-        self.save_hyperparameters()
-
-        # Network
-        self.model = UNet(n_channels, n_classes, features)
-
-    def forward(self, x):
-        return self.model(x)
-
-    def shared_step(self, batch):
-        data, ground_truth = batch  # unpacking
-        ground_truth = ground_truth.unsqueeze(1)  # 1HW -> HW
-
-        # forward pass, compute masks
-        prediction = self.model(data)
-        binary = (torch.sigmoid(prediction) > 0.5).half()
-
-        # compute metrics (in dictionnary)
-        metrics = {
-            "dice": dice_loss(prediction, ground_truth),
-            "dice_bin": dice_loss(binary, ground_truth, logits=False),
-            "bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
-            "mae": torch.nn.functional.l1_loss(binary, ground_truth),
-            "accuracy": (ground_truth == binary).float().mean(),
-        }
-
-        # wrap tensors in dictionnary
-        predictions = {
-            "linear": prediction,
-            "binary": binary,
-        }
-
-        return metrics, predictions
-
-    def training_step(self, batch, batch_idx):
-        # compute metrics
-        metrics, _ = self.shared_step(batch)
-
-        # log metrics
-        self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
-
-        return metrics["dice"]
-
-    def validation_step(self, batch, batch_idx):
-        # compute metrics
-        metrics, predictions = self.shared_step(batch)
-
-        # log metrics
-        self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
-
-        return metrics, predictions
-
-    def validation_epoch_end(self, validation_outputs):
-        # unpacking
-        metricss = [v[0] for v in validation_outputs]
-        rowss = [v[1] for v in validation_outputs]
-
-        # metrics flattening
-        metrics = {
-            "dice": torch.stack([d["dice"] for d in metricss]).mean(),
-            "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
-            "bce": torch.stack([d["bce"] for d in metricss]).mean(),
-            "mae": torch.stack([d["mae"] for d in metricss]).mean(),
-            "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
-        }
-
-        # log metrics
-        self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
-
-        # rows flattening
-        rows = list(itertools.chain.from_iterable(rowss))
-        columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
-
-        # log table
-        wandb.log(
-            {
-                "val/predictions": wandb.Table(
-                    columns=columns,
-                    data=rows,
-                )
-            }
-        )
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.RMSprop(
-            self.parameters(),
-            lr=wandb.config.LEARNING_RATE,
-            weight_decay=wandb.config.WEIGHT_DECAY,
-            momentum=wandb.config.MOMENTUM,
-        )
-
-        return optimizer
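The utils.dice implementation itself is not part of this diff; purely as an illustration of the kind of metric the removed module computed, a generic soft-Dice loss (sigmoid applied when logits=True) could look like the sketch below. The actual utils.dice.dice_loss may differ.

import torch

def soft_dice_loss(pred: torch.Tensor, target: torch.Tensor, logits: bool = True, eps: float = 1e-6) -> torch.Tensor:
    """Generic soft-Dice loss sketch: 1 - 2*|P∩G| / (|P| + |G|). Not the repo's utils.dice."""
    if logits:
        pred = torch.sigmoid(pred)
    # flatten everything except the batch dimension
    pred = pred.flatten(1)
    target = target.flatten(1)
    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    dice = (2 * intersection + eps) / (union + eps)
    return 1 - dice.mean()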


@@ -17,11 +17,17 @@ AMP:
 PIN_MEMORY:
   value: True
 BENCHMARK:
+  value: True
+DETERMINISTIC:
   value: False
+PRECISION:
+  value: 16
+SEED:
+  value: 69420
 DEVICE:
   value: gpu
 WORKERS:
-  value: 1
+  value: 16
 IMG_SIZE:
   value: 512
@@ -29,17 +35,17 @@ SPHERES:
   value: 3
 EPOCHS:
-  value: 3
+  value: 10
 TRAIN_BATCH_SIZE:
-  value: 2 # 100
+  value: 8
 VAL_BATCH_SIZE:
-  value: 0 # 10
+  value: 0
 PREFETCH_FACTOR:
-  value: 1
+  value: 2
 LEARNING_RATE:
-  value: 1.0e-4
+  value: 0.005
 WEIGHT_DECAY:
-  value: 1.0e-8
+  value: 0.0005
 MOMENTUM:
   value: 0.9
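For reference, a minimal sketch (hypothetical script, run offline) of how these YAML entries surface as wandb.config attributes once a run is initialized; WandbLogger(config="wandb.yaml") forwards the same path to wandb.init:

import wandb

# wandb.init accepts a path to a YAML file whose "value" entries populate wandb.config
wandb.init(project="Mask R-CNN", config="wandb.yaml", mode="offline")

print(wandb.config.SEED)              # 69420
print(wandb.config.TRAIN_BATCH_SIZE)  # 8
print(wandb.config.PRECISION)         # 16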