feat: add more config fields
Former-commit-id: 4b288d6c47aec460704348a003b48620327c218d [formerly 1af19e81906cd94483f7e9aadc7fdd72d35dc5e1]
Former-commit-id: 605c529e3db6dde13cbbcf41f9517c3380b69b29
parent 382b3ff91a
commit 8d6115afb2
18 src/data.py
@@ -19,8 +19,13 @@ class Spheres(pl.LightningDataModule):
     def train_dataloader(self):
         transforms = A.Compose(
             [
-                A.ToFloat(max_value=255),
-                ToTensorV2(),
+                A.Normalize(
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225],
+                    max_pixel_value=255,
+                ), # [0, 255] -> [0.0, 1.0] normalized
+                # A.ToFloat(max_value=255),
+                ToTensorV2(), # HWC -> CHW
             ],
             bbox_params=A.BboxParams(
                 format="pascal_voc",
@@ -30,17 +35,18 @@ class Spheres(pl.LightningDataModule):
             ),
         )
 
-        dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
-        print(f"len(dataset)={len(dataset)}")
+        dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
         dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
+        # dataset = Subset(dataset, list(range(20))) # somehow this allows to better utilize the gpu
 
         return DataLoader(
             dataset,
-            shuffle=True,
+            shuffle=False,
+            persistent_workers=True,
             prefetch_factor=wandb.config.PREFETCH_FACTOR,
             batch_size=wandb.config.TRAIN_BATCH_SIZE,
-            num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
+            num_workers=wandb.config.WORKERS,
             collate_fn=collate_fn,
         )
 
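Note on the transform change: A.Normalize with the ImageNet statistics maps [0, 255] uint8 input to roughly zero-mean, unit-variance floats, whereas the old A.ToFloat only rescaled to [0, 1]. A standalone sketch of the new pipeline (the dummy image and the label_fields argument are illustrative assumptions; the diff truncates the full BboxParams call):

import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

transforms = A.Compose(
    [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255),
        ToTensorV2(),  # HWC numpy array -> CHW torch tensor
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
out = transforms(image=image, bboxes=[(10, 10, 100, 100)], labels=[1])
print(out["image"].shape)  # torch.Size([3, 512, 512]), values roughly zero-centered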
21 src/train.py
@@ -7,7 +7,6 @@ from pytorch_lightning.loggers import WandbLogger
 import wandb
 from data import Spheres
 from mrcnn import MRCNNModule
-from unet import UNetModule
 from utils import ArtifactLog, TableLog
 
 if __name__ == "__main__":
@@ -16,7 +15,7 @@ if __name__ == "__main__":
 
     # setup wandb
     logger = WandbLogger(
-        project="U-Net",
+        project="Mask R-CNN",
         config="wandb.yaml",
         settings=wandb.Settings(
             code_dir="./src/",
@@ -24,15 +23,9 @@ if __name__ == "__main__":
     )
 
     # seed random generators
-    pl.seed_everything(69420, workers=True)
+    pl.seed_everything(wandb.config.SEED, workers=True)
 
-    # Create network
-    # model = UNetModule(
-    #     n_channels=wandb.config.N_CHANNELS,
-    #     n_classes=wandb.config.N_CLASSES,
-    #     features=wandb.config.FEATURES,
-    # )
-
+    # Create Network
     model = MRCNNModule(
         hidden_layer_size=-1,
         n_classes=2,
@@ -54,13 +47,13 @@ if __name__ == "__main__":
         max_epochs=wandb.config.EPOCHS,
         accelerator=wandb.config.DEVICE,
         benchmark=wandb.config.BENCHMARK,
-        deterministic=True,
-        precision=16,
+        deterministic=wandb.config.DETERMINISTIC,
+        precision=wandb.config.PRECISION,
         logger=logger,
-        log_every_n_steps=1,
+        log_every_n_steps=5,
         # val_check_interval=100,
         callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
-        # profiler="simple",
+        # profiler="advanced",
         num_sanity_val_steps=0,
     )
 
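Taken together, these hunks move every run-time knob (seed, precision, determinism, logging cadence) out of the script and into wandb.yaml. A condensed sketch of the resulting wiring, assuming the config keys shown in the wandb.yaml diff below:

import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger

# WandbLogger loads wandb.yaml, then every knob is read back from wandb.config
logger = WandbLogger(project="Mask R-CNN", config="wandb.yaml")
pl.seed_everything(wandb.config.SEED, workers=True)  # 69420 via wandb.yaml
trainer = pl.Trainer(
    max_epochs=wandb.config.EPOCHS,
    accelerator=wandb.config.DEVICE,
    benchmark=wandb.config.BENCHMARK,
    deterministic=wandb.config.DETERMINISTIC,
    precision=wandb.config.PRECISION,
    logger=logger,
)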
1 src/unet/__init__.py (deleted)
@@ -1 +0,0 @@
-from .module import UNetModule
82 src/unet/blocks.py (deleted)
@@ -1,82 +0,0 @@
-"""Parts of the U-Net model."""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class DoubleConv(nn.Module):
-    """(convolution => [BN] => ReLU) * 2"""
-
-    def __init__(self, in_channels, out_channels, mid_channels=None):
-        super().__init__()
-
-        if not mid_channels:
-            mid_channels = out_channels
-
-        self.double_conv = nn.Sequential(
-            # first convolution
-            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm2d(mid_channels),
-            nn.ReLU(inplace=True),
-            # second convolution
-            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(inplace=True),
-        )
-
-    def forward(self, x):
-        return self.double_conv(x)
-
-
-class Down(nn.Module):
-    """Downscaling with maxpool then double conv"""
-
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-
-        self.maxpool_conv = nn.Sequential(
-            nn.MaxPool2d(2),
-            DoubleConv(in_channels, out_channels),
-        )
-
-    def forward(self, x):
-        return self.maxpool_conv(x)
-
-
-class Up(nn.Module):
-    """Upscaling then double conv"""
-
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-
-        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
-        self.conv = DoubleConv(in_channels, out_channels)
-
-    def forward(self, x1, x2):
-        x1 = self.up(x1)
-
-        # input is CHW
-        diffY = x2.size()[2] - x1.size()[2]
-        diffX = x2.size()[3] - x1.size()[3]
-        diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
-        diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
-
-        x1 = F.pad(
-            input=x1,
-            pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
-        )
-
-        x = torch.cat([x2, x1], dim=1)
-
-        return self.conv(x)
-
-
-class OutConv(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(OutConv, self).__init__()
-
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
-
-    def forward(self, x):
-        return self.conv(x)
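The deleted Up block pads the upsampled feature map so it matches the skip connection before concatenation, which makes odd spatial sizes safe. A quick shape check of the removed blocks (dummy sizes, assumes torch and the classes above):

import torch

x = torch.randn(1, 64, 57, 57)  # odd size exercises the F.pad branch
down = Down(64, 128)
up = Up(128, 64)
y = down(x)     # maxpool halves 57 -> 28, channels 64 -> 128
out = up(y, x)  # upsample 28 -> 56, pad to 57, concat with skip, conv back to 64
print(y.shape, out.shape)  # (1, 128, 28, 28) (1, 64, 57, 57)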
43 src/unet/model.py (deleted)
@@ -1,43 +0,0 @@
-"""Full assembly of the parts to form the complete network."""
-
-import torch.nn as nn
-
-from .blocks import *
-
-
-class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
-        super(UNet, self).__init__()
-
-        # Network
-        self.inc = DoubleConv(n_channels, features[0])
-
-        self.downs = nn.ModuleList()
-        for i in range(len(features) - 1):
-            self.downs.append(
-                Down(*features[i : i + 2]),
-            )
-
-        self.ups = nn.ModuleList()
-        for i in range(len(features) - 1):
-            self.ups.append(
-                Up(*features[-1 - i : -3 - i : -1]),
-            )
-
-        self.outc = OutConv(features[0], n_classes)
-
-    def forward(self, x):
-        skips = []
-
-        x = self.inc(x)
-
-        for down in self.downs:
-            skips.append(x)
-            x = down(x)
-
-        for up, skip in zip(self.ups, reversed(skips)):
-            x = up(x, skip)
-
-        x = self.outc(x)
-
-        return x
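The extended slice Up(*features[-1 - i : -3 - i : -1]) walks the feature list backwards in overlapping pairs, yielding Up(512, 256), Up(256, 128), Up(128, 64). A smoke test for the removed model (assumes the blocks above; 512 matches IMG_SIZE in wandb.yaml):

import torch

net = UNet(n_channels=3, n_classes=1)
x = torch.randn(1, 3, 512, 512)
y = net(x)
print(y.shape)  # torch.Size([1, 1, 512, 512]): one logit mask per pixel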
117 src/unet/module.py (deleted)
@@ -1,117 +0,0 @@
-"""Pytorch lightning wrapper for model."""
-
-import itertools
-
-import pytorch_lightning as pl
-
-import wandb
-from unet.model import UNet
-from utils.dice import dice_loss
-
-from .blocks import *
-
-class_labels = {
-    1: "sphere",
-}
-
-
-class UNetModule(pl.LightningModule):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
-        super(UNetModule, self).__init__()
-
-        # Hyperparameters
-        self.n_channels = n_channels
-        self.n_classes = n_classes
-
-        # log hyperparameters
-        self.save_hyperparameters()
-
-        # Network
-        self.model = UNet(n_channels, n_classes, features)
-
-    def forward(self, x):
-        return self.model(x)
-
-    def shared_step(self, batch):
-        data, ground_truth = batch  # unpacking
-        ground_truth = ground_truth.unsqueeze(1)  # 1HW -> HW
-
-        # forward pass, compute masks
-        prediction = self.model(data)
-        binary = (torch.sigmoid(prediction) > 0.5).half()
-
-        # compute metrics (in dictionnary)
-        metrics = {
-            "dice": dice_loss(prediction, ground_truth),
-            "dice_bin": dice_loss(binary, ground_truth, logits=False),
-            "bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
-            "mae": torch.nn.functional.l1_loss(binary, ground_truth),
-            "accuracy": (ground_truth == binary).float().mean(),
-        }
-
-        # wrap tensors in dictionnary
-        predictions = {
-            "linear": prediction,
-            "binary": binary,
-        }
-
-        return metrics, predictions
-
-    def training_step(self, batch, batch_idx):
-        # compute metrics
-        metrics, _ = self.shared_step(batch)
-
-        # log metrics
-        self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
-
-        return metrics["dice"]
-
-    def validation_step(self, batch, batch_idx):
-        # compute metrics
-        metrics, predictions = self.shared_step(batch)
-
-        # log metrics
-        self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
-
-        return metrics, predictions
-
-    def validation_epoch_end(self, validation_outputs):
-        # unpacking
-        metricss = [v[0] for v in validation_outputs]
-        rowss = [v[1] for v in validation_outputs]
-
-        # metrics flattening
-        metrics = {
-            "dice": torch.stack([d["dice"] for d in metricss]).mean(),
-            "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
-            "bce": torch.stack([d["bce"] for d in metricss]).mean(),
-            "mae": torch.stack([d["mae"] for d in metricss]).mean(),
-            "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
-        }
-
-        # log metrics
-        self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
-
-        # rows flattening
-        rows = list(itertools.chain.from_iterable(rowss))
-        columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
-
-        # log table
-        wandb.log(
-            {
-                "val/predictions": wandb.Table(
-                    columns=columns,
-                    data=rows,
-                )
-            }
-        )
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.RMSprop(
-            self.parameters(),
-            lr=wandb.config.LEARNING_RATE,
-            weight_decay=wandb.config.WEIGHT_DECAY,
-            momentum=wandb.config.MOMENTUM,
-        )
-
-        return optimizer
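The removed module imports dice_loss from utils.dice, which this diff does not show. For reference, a hypothetical soft-Dice formulation consistent with how it is called above (the smooth term and the exact reduction are assumptions, not the repository's code):

import torch

def dice_loss(prediction, target, logits=True, smooth=1.0):
    # hypothetical sketch; the repository's utils/dice.py is not part of this diff
    if logits:
        prediction = torch.sigmoid(prediction)
    p = prediction.flatten(1)
    t = target.flatten(1)
    intersection = (p * t).sum(dim=1)
    dice = (2 * intersection + smooth) / (p.sum(dim=1) + t.sum(dim=1) + smooth)
    return 1 - dice.mean()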
20 wandb.yaml
@@ -17,11 +17,17 @@ AMP:
 PIN_MEMORY:
   value: True
 BENCHMARK:
+  value: True
+DETERMINISTIC:
   value: False
+PRECISION:
+  value: 16
+SEED:
+  value: 69420
 DEVICE:
   value: gpu
 WORKERS:
-  value: 1
+  value: 16
 
 IMG_SIZE:
   value: 512
@@ -29,17 +35,17 @@ SPHERES:
   value: 3
 
 EPOCHS:
-  value: 3
+  value: 10
 TRAIN_BATCH_SIZE:
-  value: 2 # 100
+  value: 8
 VAL_BATCH_SIZE:
-  value: 0 # 10
+  value: 0
 PREFETCH_FACTOR:
-  value: 1
+  value: 2
 
 LEARNING_RATE:
-  value: 1.0e-4
+  value: 0.005
 WEIGHT_DECAY:
-  value: 1.0e-8
+  value: 0.0005
 MOMENTUM:
   value: 0.9
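Each top-level key in wandb.yaml surfaces as an attribute on wandb.config once a run is initialised with that file, which is what the wandb.config.* reads in data.py and train.py rely on. A minimal sketch, assuming wandb accepts the YAML path the same way WandbLogger passes it through (mode="offline" just avoids a network call):

import wandb

run = wandb.init(project="Mask R-CNN", config="wandb.yaml", mode="offline")
assert wandb.config.SEED == 69420
assert wandb.config.PRECISION == 16
assert wandb.config.TRAIN_BATCH_SIZE == 8
run.finish()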