feat(WIP): switching to pytorch lightning

Former-commit-id: 0038dbca182717af8fc4bd846fd5be0e9fa70a9a [formerly eb5eb0717f8511bf49de8393bbdc66e727b930ff]
Former-commit-id: 540304228b146fe8e086bc4ccb770a13f84cbbcb
This commit is contained in:
Laurent Fainsin 2022-07-04 21:40:38 +02:00
parent d785a5c6be
commit 982dfe99d7
2 changed files with 202 additions and 243 deletions

View file

@ -1,16 +1,16 @@
import logging import logging
import albumentations as A import albumentations as A
import pytorch_lightning as pl
import torch import torch
import yaml import yaml
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb import wandb
from src.utils.dataset import SphereDataset from src.utils.dataset import SphereDataset
from unet import UNet from unet import UNet
from utils.dice import dice_coeff
from utils.paste import RandomPaste from utils.paste import RandomPaste
class_labels = { class_labels = {
@ -22,7 +22,7 @@ if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# setup wandb # setup wandb
wandb.init( logger = WandbLogger(
project="U-Net", project="U-Net",
config=dict( config=dict(
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/", DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
@ -36,7 +36,7 @@ if __name__ == "__main__":
AMP=True, AMP=True,
PIN_MEMORY=True, PIN_MEMORY=True,
BENCHMARK=True, BENCHMARK=True,
DEVICE="cuda", DEVICE="gpu",
WORKERS=8, WORKERS=8,
EPOCHS=5, EPOCHS=5,
BATCH_SIZE=16, BATCH_SIZE=16,
@ -51,18 +51,17 @@ if __name__ == "__main__":
), ),
) )
# create device # seed random generators
device = torch.device(wandb.config.DEVICE) pl.seed_everything(69420, workers=True)
# enable cudnn benchmarking
torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
# 0. Create network # 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
# log the number of parameters of the model
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
# transfer network to device # log gradients and weights regularly
net.to(device=device) logger.watch(net, log="all")
# 1. Create transforms # 1. Create transforms
tf_train = A.Compose( tf_train = A.Compose(
@ -121,244 +120,38 @@ if __name__ == "__main__":
pin_memory=wandb.config.PIN_MEMORY, pin_memory=wandb.config.PIN_MEMORY,
) )
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp # 4. Create the trainer
optimizer = torch.optim.RMSprop( trainer = pl.Trainer(
net.parameters(), max_epochs=wandb.config.EPOCHS,
lr=wandb.config.LEARNING_RATE, accelerator="gpu",
weight_decay=wandb.config.WEIGHT_DECAY, precision=16,
momentum=wandb.config.MOMENTUM, auto_scale_batch_size="binsearch",
benchmark=wandb.config.BENCHMARK,
val_check_interval=100,
) )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss()
# save model.onnx
dummy_input = torch.randn(
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device)
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("checkpoints/model-0.onnx")
wandb.run.log_artifact(artifact)
# log gradients and weights four times per epoch
wandb.watch(net, criterion, log_freq=100)
# print the config # print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
# wandb init log # # wandb init log
wandb.log( # wandb.log(
{ # {
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], # "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
}, # },
commit=False, # commit=False,
) # )
try: try:
for epoch in range(1, wandb.config.EPOCHS + 1): trainer.fit(
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar: model=net,
train_dataloaders=train_loader,
# Training round val_dataloaders=val_loader,
for step, (images, true_masks) in enumerate(train_loader): test_dataloaders=test_loader,
assert images.shape[1] == net.n_channels, ( accelerator=wandb.config.DEVICE,
f"Network has been defined with {net.n_channels} input channels, " )
f"but loaded images have {images.shape[1]} channels. Please check that "
"the images are loaded correctly."
)
# transfer images to device
images = images.to(device=device)
true_masks = true_masks.unsqueeze(1).to(device=device)
# forward
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks)
# backward
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(train_loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
# compute metrics
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
accuracy = (true_masks == pred_masks_bin).float().mean()
dice = dice_coeff(pred_masks_bin, true_masks)
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
# update tqdm progress bar
pbar.update(images.shape[0])
pbar.set_postfix(**{"loss": train_loss.item()})
# log metrics
wandb.log(
{
"epoch": epoch - 1 + step / len(train_loader),
"train/accuracy": accuracy,
"train/bce": train_loss,
"train/dice": dice,
"train/mae": mae,
}
)
if step and (step % 250 == 0 or step == len(train_loader)):
# Evaluation round
net.eval()
accuracy = 0
val_loss = 0
dice = 0
mae = 0
with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
for images, masks_true in val_loader:
# transfer images to device
images = images.to(device=device)
masks_true = masks_true.unsqueeze(1).to(device=device)
# forward
with torch.inference_mode():
masks_pred = net(images)
# compute metrics
val_loss += criterion(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar
pbar2.update(images.shape[0])
accuracy /= len(val_loader)
val_loss /= len(val_loader)
dice /= len(val_loader)
mae /= len(val_loader)
# save the last validation batch to table
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
table.add_data(
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
)
# log validation metrics
wandb.log(
{
"val/predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy,
"val/bce": val_loss,
"val/dice": dice,
"val/mae": mae,
},
commit=False,
)
# update hyperparameters
net.train()
scheduler.step(dice)
# export model to onnx format when validation ends
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
wandb.run.log_artifact(artifact)
# testing round
net.eval()
accuracy = 0
val_loss = 0
dice = 0
mae = 0
with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
for images, masks_true in test_loader:
# transfer images to device
images = images.to(device=device)
masks_true = masks_true.unsqueeze(1).to(device=device)
# forward
with torch.inference_mode():
masks_pred = net(images)
# compute metrics
val_loss += criterion(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar
pbar3.update(images.shape[0])
accuracy /= len(test_loader)
val_loss /= len(test_loader)
dice /= len(test_loader)
mae /= len(test_loader)
# save the last validation batch to table
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
table.add_data(
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
)
# log validation metrics
wandb.log(
{
"test/predictions": table,
"test/accuracy": accuracy,
"test/bce": val_loss,
"test/dice": dice,
"test/mae": mae,
},
commit=False,
)
# stop wandb
wandb.run.finish()
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth") torch.save(net.state_dict(), "INTERRUPTED.pth")
raise raise
# sapin de noel # stop wandb
wandb.run.finish()

View file

@ -1,9 +1,21 @@
""" Full assembly of the parts to form the complete network """ """ Full assembly of the parts to form the complete network """
from xmlrpc.server import list_public_methods
import numpy as np
import pytorch_lightning as pl
import wandb
from utils.dice import dice_coeff
from .blocks import * from .blocks import *
class_labels = {
1: "sphere",
}
class UNet(nn.Module):
class UNet(pl.LightningModule):
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
super(UNet, self).__init__() super(UNet, self).__init__()
self.n_channels = n_channels self.n_channels = n_channels
@ -26,7 +38,6 @@ class UNet(nn.Module):
self.outc = OutConv(features[0], n_classes) self.outc = OutConv(features[0], n_classes)
def forward(self, x): def forward(self, x):
skips = [] skips = []
x = self.inc(x) x = self.inc(x)
@ -41,3 +52,158 @@ class UNet(nn.Module):
x = self.outc(x) x = self.outc(x)
return x return x
@staticmethod
def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
    """Log one batch of samples to Weights & Biases as a table.

    Each row holds a running ID, the input image, the ground-truth mask and
    the raw prediction with its binarised mask attached as a segmentation
    overlay labelled through the module-level ``class_labels``.

    Args:
        images: batch of input images (tensor; moved to CPU here).
        masks_true: batch of ground-truth masks (tensor; moved to CPU here).
        masks_pred: batch of raw network outputs (tensor; moved to CPU here).
        masks_pred_bin: batch of thresholded predictions; axis 1 is squeezed
            and the values are cast to int before conversion to NumPy.
        log_key: key under which the table is sent to ``wandb.log``.
    """
    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])

    # Bring everything back to host memory before handing it to wandb.
    host_images = images.cpu()
    host_truths = masks_true.cpu()
    host_preds = masks_pred.cpu()
    host_bins = masks_pred_bin.cpu().squeeze(1).int().numpy()

    row_id = 0
    for img, mask, pred, pred_bin in zip(host_images, host_truths, host_preds, host_bins):
        overlay = {
            "predictions": {
                "mask_data": pred_bin,
                "class_labels": class_labels,
            },
        }
        table.add_data(
            row_id,
            wandb.Image(img),
            wandb.Image(mask),
            wandb.Image(pred, masks=overlay),
        )
        row_id += 1

    wandb.log({log_key: table})
def training_step(self, batch, batch_idx):
    """Run the forward pass for one training batch and log its metrics.

    Args:
        batch: pair ``(images, masks_true)`` produced by the training
            dataloader.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        dict: ``"loss"`` holds the differentiable loss that Lightning
        back-propagates; ``"dice"``, ``"accuracy"`` and ``"mae"`` carry the
        metrics computed on the thresholded prediction.
    """
    # unpacking
    images, masks_true = batch
    # add a channel axis so the target lines up with the prediction
    # (assumes masks arrive as float (N, H, W) tensors -- TODO confirm)
    masks_true = masks_true.unsqueeze(1)

    masks_pred = self(images)
    masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()

    # compute metrics
    # Predictions are thresholded with a sigmoid and the loss is reported as
    # "bce", so the matching criterion is binary cross-entropy on logits.
    # Softmax cross-entropy over a single output channel is constant and
    # gives no training signal.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(masks_pred, masks_true)
    mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
    accuracy = (masks_true == masks_pred_bin).float().mean()
    dice = dice_coeff(masks_pred_bin, masks_true)

    wandb.log(
        {
            "train/accuracy": accuracy,
            "train/bce": loss,
            "train/dice": dice,
            "train/mae": mae,
        }
    )

    # Lightning only accepts a tensor, None, or a dict containing "loss"
    # from training_step; a bare tuple is rejected.
    return {"loss": loss, "dice": dice, "accuracy": accuracy, "mae": mae}
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch.

    The first batch of every validation run is also sent to W&B as a table
    of predictions under ``"val/predictions"``.

    Args:
        batch: pair ``(images, masks_true)`` produced by the validation
            dataloader.
        batch_idx: index of the batch within the validation run.

    Returns:
        tuple: ``(loss, dice, accuracy, mae)`` for this batch.
    """
    # unpacking
    images, masks_true = batch
    # add a channel axis so the target lines up with the prediction
    # (assumes masks arrive as float (N, H, W) tensors -- TODO confirm)
    masks_true = masks_true.unsqueeze(1)

    masks_pred = self(images)
    masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()

    # compute metrics
    # Binary cross-entropy on logits matches the sigmoid thresholding above
    # and the "bce" metric name; softmax cross-entropy over a single output
    # channel would be constant.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(masks_pred, masks_true)
    mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
    accuracy = (masks_true == masks_pred_bin).float().mean()
    dice = dice_coeff(masks_pred_bin, masks_true)

    # only the first batch is logged as images, to keep the upload small
    if batch_idx == 0:
        self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")

    return loss, dice, accuracy, mae
def validation_step_end(self, validation_outputs):
    """Log the metrics of one validation step and export the model to ONNX.

    Args:
        validation_outputs: the ``(loss, dice, accuracy, mae)`` tuple
            returned by ``validation_step``.
    """
    # unpacking
    loss, dice, accuracy, mae = validation_outputs

    # `optimizers` is a LightningModule method, not a list: it has to be
    # called, and it returns either a single optimizer or a list of them.
    optimizers = self.optimizers()
    optimizer = optimizers[0] if isinstance(optimizers, (list, tuple)) else optimizers
    learning_rate = optimizer.param_groups[0]["lr"]

    wandb.log(
        {
            "train/learning_rate": learning_rate,
            "val/accuracy": accuracy,
            "val/bce": loss,
            "val/dice": dice,
            "val/mae": mae,
        }
    )

    # Also publish the dice score as a Lightning metric so that anything
    # monitoring "val/dice" (schedulers, callbacks) can read it;
    # logger=False avoids sending it to the trainer's logger a second time.
    self.log("val/dice", dice, logger=False)

    # export model to onnx
    # The dummy input must live on the same device as the weights, and its
    # channel count must match what the network was built with.
    dummy_input = torch.randn(1, self.n_channels, 512, 512, requires_grad=True, device=self.device)
    torch.onnx.export(self, dummy_input, "checkpoints/model.onnx")
    artifact = wandb.Artifact("onnx", type="model")
    artifact.add_file("checkpoints/model.onnx")
    wandb.run.log_artifact(artifact)
def test_step(self, batch, batch_idx):
    """Evaluate one test batch.

    The first batch of the test run is also sent to W&B as a table of
    predictions under ``"test/predictions"``.

    Args:
        batch: pair ``(images, masks_true)`` produced by the test
            dataloader.
        batch_idx: index of the batch within the test run.

    Returns:
        tuple: ``(loss, dice, accuracy, mae)`` for this batch.
    """
    # unpacking
    images, masks_true = batch
    # add a channel axis so the target lines up with the prediction
    # (assumes masks arrive as float (N, H, W) tensors -- TODO confirm)
    masks_true = masks_true.unsqueeze(1)

    masks_pred = self(images)
    masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()

    # compute metrics
    # Binary cross-entropy on logits matches the sigmoid thresholding above;
    # softmax cross-entropy over a single output channel would be constant.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(masks_pred, masks_true)
    mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
    accuracy = (masks_true == masks_pred_bin).float().mean()
    dice = dice_coeff(masks_pred_bin, masks_true)

    # only the first batch is logged as images, to keep the upload small
    if batch_idx == 0:
        self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")

    return loss, dice, accuracy, mae
def test_step_end(self, test_outputs):
    """Log the metrics of a test step to W&B under the ``test/`` prefix.

    Args:
        test_outputs: the ``(loss, dice, accuracy, mae)`` tuple returned by
            ``test_step``. Each entry may be a scalar or a list/tuple of
            scalars; sequences are averaged.
    """

    def _mean(value):
        # Reduce with torch rather than NumPy: the values are torch scalars
        # that may live on the GPU, which np.mean cannot convert.
        if isinstance(value, (list, tuple)):
            value = torch.stack([torch.as_tensor(item) for item in value])
        return torch.as_tensor(value).float().mean()

    # unpacking
    list_loss, list_dice, list_accuracy, list_mae = test_outputs

    # averaging
    loss = _mean(list_loss)
    dice = _mean(list_dice)
    accuracy = _mean(list_accuracy)
    mae = _mean(list_mae)

    # Test metrics get their own namespace so they do not overwrite the
    # validation curves. No learning rate is logged here: no optimisation
    # happens during testing and an optimizer may not be attached.
    wandb.log(
        {
            "test/accuracy": accuracy,
            "test/bce": loss,
            "test/dice": dice,
            "test/mae": mae,
        }
    )
def configure_optimizers(self):
    """Build the optimizer and learning-rate scheduler used by the trainer.

    The optimizer is RMSprop configured from ``wandb.config``
    (``LEARNING_RATE``, ``WEIGHT_DECAY``, ``MOMENTUM``). The scheduler is
    ``ReduceLROnPlateau`` in ``"max"`` mode with a patience of 2.

    Returns:
        dict: Lightning optimizer configuration with the keys
        ``"optimizer"`` and ``"lr_scheduler"``.
    """
    optimizer = torch.optim.RMSprop(
        self.parameters(),
        lr=wandb.config.LEARNING_RATE,
        weight_decay=wandb.config.WEIGHT_DECAY,
        momentum=wandb.config.MOMENTUM,
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "max",
        patience=2,
    )

    # A bare (optimizer, scheduler) tuple is not one of the return forms
    # Lightning accepts, and ReduceLROnPlateau needs to be told which metric
    # to watch. "max" mode means the monitored value should increase, hence
    # the dice score. The metric must be published through `self.log`;
    # strict=False turns a missing metric into a warning instead of
    # aborting training.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val/dice",
            "strict": False,
        },
    }