feat: add more config fields
Former-commit-id: 4b288d6c47aec460704348a003b48620327c218d [formerly 1af19e81906cd94483f7e9aadc7fdd72d35dc5e1] Former-commit-id: 605c529e3db6dde13cbbcf41f9517c3380b69b29
This commit is contained in:
parent
382b3ff91a
commit
8d6115afb2
|
@ -19,8 +19,13 @@ class Spheres(pl.LightningDataModule):
|
|||
def train_dataloader(self):
|
||||
transforms = A.Compose(
|
||||
[
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
A.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225],
|
||||
max_pixel_value=255,
|
||||
), # [0, 255] -> [0.0, 1.0] normalized
|
||||
# A.ToFloat(max_value=255),
|
||||
ToTensorV2(), # HWC -> CHW
|
||||
],
|
||||
bbox_params=A.BboxParams(
|
||||
format="pascal_voc",
|
||||
|
@ -30,17 +35,18 @@ class Spheres(pl.LightningDataModule):
|
|||
),
|
||||
)
|
||||
|
||||
dataset = RealDataset(root="/media/disk1/lfainsin/TEST_tmp_mrcnn/", transforms=transforms)
|
||||
print(f"len(dataset)={len(dataset)}")
|
||||
dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
|
||||
dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
|
||||
# dataset = Subset(dataset, list(range(20))) # somehow this allows to better utilize the gpu
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
shuffle=False,
|
||||
persistent_workers=True,
|
||||
prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||
batch_size=wandb.config.TRAIN_BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
|
|
21
src/train.py
21
src/train.py
|
@ -7,7 +7,6 @@ from pytorch_lightning.loggers import WandbLogger
|
|||
import wandb
|
||||
from data import Spheres
|
||||
from mrcnn import MRCNNModule
|
||||
from unet import UNetModule
|
||||
from utils import ArtifactLog, TableLog
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -16,7 +15,7 @@ if __name__ == "__main__":
|
|||
|
||||
# setup wandb
|
||||
logger = WandbLogger(
|
||||
project="U-Net",
|
||||
project="Mask R-CNN",
|
||||
config="wandb.yaml",
|
||||
settings=wandb.Settings(
|
||||
code_dir="./src/",
|
||||
|
@ -24,15 +23,9 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
# seed random generators
|
||||
pl.seed_everything(69420, workers=True)
|
||||
|
||||
# Create network
|
||||
# model = UNetModule(
|
||||
# n_channels=wandb.config.N_CHANNELS,
|
||||
# n_classes=wandb.config.N_CLASSES,
|
||||
# features=wandb.config.FEATURES,
|
||||
# )
|
||||
pl.seed_everything(wandb.config.SEED, workers=True)
|
||||
|
||||
# Create Network
|
||||
model = MRCNNModule(
|
||||
hidden_layer_size=-1,
|
||||
n_classes=2,
|
||||
|
@ -54,13 +47,13 @@ if __name__ == "__main__":
|
|||
max_epochs=wandb.config.EPOCHS,
|
||||
accelerator=wandb.config.DEVICE,
|
||||
benchmark=wandb.config.BENCHMARK,
|
||||
deterministic=True,
|
||||
precision=16,
|
||||
deterministic=wandb.config.DETERMINISTIC,
|
||||
precision=wandb.config.PRECISION,
|
||||
logger=logger,
|
||||
log_every_n_steps=1,
|
||||
log_every_n_steps=5,
|
||||
# val_check_interval=100,
|
||||
callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
|
||||
# profiler="simple",
|
||||
# profiler="advanced",
|
||||
num_sanity_val_steps=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
from .module import UNetModule
|
|
@ -1,82 +0,0 @@
|
|||
"""Parts of the U-Net model."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class DoubleConv(nn.Module):
|
||||
"""(convolution => [BN] => ReLU) * 2"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, mid_channels=None):
|
||||
super().__init__()
|
||||
|
||||
if not mid_channels:
|
||||
mid_channels = out_channels
|
||||
|
||||
self.double_conv = nn.Sequential(
|
||||
# first convolution
|
||||
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
# second convolution
|
||||
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.double_conv(x)
|
||||
|
||||
|
||||
class Down(nn.Module):
|
||||
"""Downscaling with maxpool then double conv"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self.maxpool_conv = nn.Sequential(
|
||||
nn.MaxPool2d(2),
|
||||
DoubleConv(in_channels, out_channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.maxpool_conv(x)
|
||||
|
||||
|
||||
class Up(nn.Module):
|
||||
"""Upscaling then double conv"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
|
||||
self.conv = DoubleConv(in_channels, out_channels)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x1 = self.up(x1)
|
||||
|
||||
# input is CHW
|
||||
diffY = x2.size()[2] - x1.size()[2]
|
||||
diffX = x2.size()[3] - x1.size()[3]
|
||||
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
|
||||
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")
|
||||
|
||||
x1 = F.pad(
|
||||
input=x1,
|
||||
pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
|
||||
)
|
||||
|
||||
x = torch.cat([x2, x1], dim=1)
|
||||
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class OutConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(OutConv, self).__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
|
@ -1,43 +0,0 @@
|
|||
"""Full assembly of the parts to form the complete network."""
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import *
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||
super(UNet, self).__init__()
|
||||
|
||||
# Network
|
||||
self.inc = DoubleConv(n_channels, features[0])
|
||||
|
||||
self.downs = nn.ModuleList()
|
||||
for i in range(len(features) - 1):
|
||||
self.downs.append(
|
||||
Down(*features[i : i + 2]),
|
||||
)
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i in range(len(features) - 1):
|
||||
self.ups.append(
|
||||
Up(*features[-1 - i : -3 - i : -1]),
|
||||
)
|
||||
|
||||
self.outc = OutConv(features[0], n_classes)
|
||||
|
||||
def forward(self, x):
|
||||
skips = []
|
||||
|
||||
x = self.inc(x)
|
||||
|
||||
for down in self.downs:
|
||||
skips.append(x)
|
||||
x = down(x)
|
||||
|
||||
for up, skip in zip(self.ups, reversed(skips)):
|
||||
x = up(x, skip)
|
||||
|
||||
x = self.outc(x)
|
||||
|
||||
return x
|
|
@ -1,117 +0,0 @@
|
|||
"""Pytorch lightning wrapper for model."""
|
||||
|
||||
import itertools
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
import wandb
|
||||
from unet.model import UNet
|
||||
from utils.dice import dice_loss
|
||||
|
||||
from .blocks import *
|
||||
|
||||
class_labels = {
|
||||
1: "sphere",
|
||||
}
|
||||
|
||||
|
||||
class UNetModule(pl.LightningModule):
|
||||
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||
super(UNetModule, self).__init__()
|
||||
|
||||
# Hyperparameters
|
||||
self.n_channels = n_channels
|
||||
self.n_classes = n_classes
|
||||
|
||||
# log hyperparameters
|
||||
self.save_hyperparameters()
|
||||
|
||||
# Network
|
||||
self.model = UNet(n_channels, n_classes, features)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
def shared_step(self, batch):
|
||||
data, ground_truth = batch # unpacking
|
||||
ground_truth = ground_truth.unsqueeze(1) # 1HW -> HW
|
||||
|
||||
# forward pass, compute masks
|
||||
prediction = self.model(data)
|
||||
binary = (torch.sigmoid(prediction) > 0.5).half()
|
||||
|
||||
# compute metrics (in dictionnary)
|
||||
metrics = {
|
||||
"dice": dice_loss(prediction, ground_truth),
|
||||
"dice_bin": dice_loss(binary, ground_truth, logits=False),
|
||||
"bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
|
||||
"mae": torch.nn.functional.l1_loss(binary, ground_truth),
|
||||
"accuracy": (ground_truth == binary).float().mean(),
|
||||
}
|
||||
|
||||
# wrap tensors in dictionnary
|
||||
predictions = {
|
||||
"linear": prediction,
|
||||
"binary": binary,
|
||||
}
|
||||
|
||||
return metrics, predictions
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# compute metrics
|
||||
metrics, _ = self.shared_step(batch)
|
||||
|
||||
# log metrics
|
||||
self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
|
||||
|
||||
return metrics["dice"]
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
# compute metrics
|
||||
metrics, predictions = self.shared_step(batch)
|
||||
|
||||
# log metrics
|
||||
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
|
||||
|
||||
return metrics, predictions
|
||||
|
||||
def validation_epoch_end(self, validation_outputs):
|
||||
# unpacking
|
||||
metricss = [v[0] for v in validation_outputs]
|
||||
rowss = [v[1] for v in validation_outputs]
|
||||
|
||||
# metrics flattening
|
||||
metrics = {
|
||||
"dice": torch.stack([d["dice"] for d in metricss]).mean(),
|
||||
"dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
|
||||
"bce": torch.stack([d["bce"] for d in metricss]).mean(),
|
||||
"mae": torch.stack([d["mae"] for d in metricss]).mean(),
|
||||
"accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
|
||||
}
|
||||
|
||||
# log metrics
|
||||
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
|
||||
|
||||
# rows flattening
|
||||
rows = list(itertools.chain.from_iterable(rowss))
|
||||
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||
|
||||
# log table
|
||||
wandb.log(
|
||||
{
|
||||
"val/predictions": wandb.Table(
|
||||
columns=columns,
|
||||
data=rows,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.RMSprop(
|
||||
self.parameters(),
|
||||
lr=wandb.config.LEARNING_RATE,
|
||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||
momentum=wandb.config.MOMENTUM,
|
||||
)
|
||||
|
||||
return optimizer
|
20
wandb.yaml
20
wandb.yaml
|
@ -17,11 +17,17 @@ AMP:
|
|||
PIN_MEMORY:
|
||||
value: True
|
||||
BENCHMARK:
|
||||
value: True
|
||||
DETERMINISTIC:
|
||||
value: False
|
||||
PRECISION:
|
||||
value: 16
|
||||
SEED:
|
||||
value: 69420
|
||||
DEVICE:
|
||||
value: gpu
|
||||
WORKERS:
|
||||
value: 1
|
||||
value: 16
|
||||
|
||||
IMG_SIZE:
|
||||
value: 512
|
||||
|
@ -29,17 +35,17 @@ SPHERES:
|
|||
value: 3
|
||||
|
||||
EPOCHS:
|
||||
value: 3
|
||||
value: 10
|
||||
TRAIN_BATCH_SIZE:
|
||||
value: 2 # 100
|
||||
value: 8
|
||||
VAL_BATCH_SIZE:
|
||||
value: 0 # 10
|
||||
value: 0
|
||||
PREFETCH_FACTOR:
|
||||
value: 1
|
||||
value: 2
|
||||
|
||||
LEARNING_RATE:
|
||||
value: 1.0e-4
|
||||
value: 0.005
|
||||
WEIGHT_DECAY:
|
||||
value: 1.0e-8
|
||||
value: 0.0005
|
||||
MOMENTUM:
|
||||
value: 0.9
|
||||
|
|
Loading…
Reference in a new issue