refactor: move the table logging into a callback

Former-commit-id: 37fa7b0da4556417ee1665d9f0375e71ce075958 [formerly 9314a3dfee09b085bb0d125d178c9f589532c6f5]
Former-commit-id: 9c8bef3f5c6560f26512bc6586f1337b9d7985f3
This commit is contained in:
Laurent Fainsin 2022-07-11 14:41:18 +02:00
parent edcc3c7bb2
commit 4015dad491
5 changed files with 102 additions and 91 deletions

View file

@ -36,6 +36,19 @@ class Spheres(pl.LightningDataModule):
pin_memory=wandb.config.PIN_MEMORY, pin_memory=wandb.config.PIN_MEMORY,
) )
# dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
# return DataLoader(
# dataset,
# shuffle=True,
# batch_size=8,
# prefetch_factor=8,
# num_workers=wandb.config.WORKERS,
# pin_memory=wandb.config.PIN_MEMORY,
# )
def val_dataloader(self): def val_dataloader(self):
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG) dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1))) dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

View file

@ -1,17 +1,19 @@
import logging import logging
import pytorch_lightning as pl import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
import wandb import wandb
from data import Spheres from data import Spheres
from unet import UNetModule from unet import UNetModule
from utils import TableLog
CONFIG = { CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
"DIR_VALID_IMG": "//home/lilian/data_disk/lfainsin/test_split/", "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real_split/", "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
"FEATURES": [8, 16, 32, 64], "FEATURES": [8, 16, 32, 64],
"N_CHANNELS": 3, "N_CHANNELS": 3,
"N_CLASSES": 1, "N_CLASSES": 1,
@ -19,9 +21,9 @@ CONFIG = {
"PIN_MEMORY": True, "PIN_MEMORY": True,
"BENCHMARK": True, "BENCHMARK": True,
"DEVICE": "gpu", "DEVICE": "gpu",
"WORKERS": 14, "WORKERS": 8,
"EPOCHS": 1, "EPOCHS": 10,
"BATCH_SIZE": 16 * 3, "BATCH_SIZE": 16,
"LEARNING_RATE": 1e-4, "LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 1e-8, "WEIGHT_DECAY": 1e-8,
"MOMENTUM": 0.9, "MOMENTUM": 0.9,
@ -54,12 +56,18 @@ if __name__ == "__main__":
features=CONFIG["FEATURES"], features=CONFIG["FEATURES"],
) )
# load checkpoint
state_dict = torch.load("checkpoints/synth.pth")
state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
model.load_state_dict(state_dict)
# log gradients and weights regularly # log gradients and weights regularly
logger.watch(model, log="all") logger.watch(model, log="all")
# create checkpoint callback # create checkpoint callback
checkpoint_callback = ModelCheckpoint( checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints", dirpath="checkpoints",
filename="model.ckpt",
monitor="val/dice", monitor="val/dice",
) )
@ -75,8 +83,8 @@ if __name__ == "__main__":
# precision=16, # precision=16,
logger=logger, logger=logger,
log_every_n_steps=1, log_every_n_steps=1,
val_check_interval=100, val_check_interval=25,
callbacks=RichProgressBar(), callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
) )
trainer.fit(model=model, datamodule=datamodule) trainer.fit(model=model, datamodule=datamodule)

View file

@ -52,105 +52,30 @@ class UNetModule(pl.LightningModule):
} }
# wrap tensors in dictionnary # wrap tensors in dictionnary
tensors = { predictions = {
"data": data, "linear": prediction,
"ground_truth": ground_truth,
"prediction": prediction,
"binary": binary, "binary": binary,
} }
return metrics, tensors return metrics, predictions
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# compute metrics # compute metrics
metrics, tensors = self.shared_step(batch) metrics, _ = self.shared_step(batch)
# log metrics # log metrics
self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()])) self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
if batch_idx == 5000:
rows = []
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
for i, (img, mask, pred, pred_bin) in enumerate(
zip( # TODO: use comprehension list to zip the dictionnary
tensors["data"].cpu(),
tensors["ground_truth"].cpu(),
tensors["prediction"].cpu(),
tensors["binary"]
.cpu()
.squeeze(1)
.int()
.numpy(), # TODO: check if .functions can be moved elsewhere
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
metrics["dice"],
metrics["dice_bin"],
]
)
# log table
wandb.log(
{
"train/predictions": wandb.Table(
columns=columns,
data=rows,
)
}
)
return metrics["dice"] return metrics["dice"]
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
metrics, tensors = self.shared_step(batch) # compute metrics
metrics, predictions = self.shared_step(batch)
rows = [] # log metrics
if batch_idx % 50 == 0 or metrics["dice"] > 0.9: self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
for i, (img, mask, pred, pred_bin) in enumerate(
zip( # TODO: use comprehension list to zip the dictionnary
tensors["data"].cpu(),
tensors["ground_truth"].cpu(),
tensors["prediction"].cpu(),
tensors["binary"]
.cpu()
.squeeze(1)
.int()
.numpy(), # TODO: check if .functions can be moved elsewhere
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
metrics["dice"],
metrics["dice_bin"],
]
)
return metrics, rows return metrics, predictions
def validation_epoch_end(self, validation_outputs): def validation_epoch_end(self, validation_outputs):
# unpacking # unpacking

View file

@ -1 +1,2 @@
from .callback import TableLog
from .paste import RandomPaste from .paste import RandomPaste

64
src/utils/callback.py Normal file
View file

@ -0,0 +1,64 @@
from pytorch_lightning.callbacks import Callback
from torch import tensor
import wandb
# Column headers of the wandb table; each logged row supplies its values in this order.
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]

# Passed to wandb.Image as the "class_labels" of the overlaid prediction mask.
class_labels = {1: "sphere"}
class TableLog(Callback):
    """Lightning callback that logs validation predictions as a wandb table.

    Rows are accumulated batch by batch during a validation epoch and sent to
    wandb as one table under the "val/predictions" key when the epoch ends.
    """

    def __init__(self):
        super().__init__()
        # Defined up front so on_validation_epoch_end never hits a missing
        # attribute, even if it runs before on_validation_epoch_start.
        self.rows = []

    def on_validation_epoch_start(self, trainer, pl_module):
        """Drop the rows collected during the previous validation epoch."""
        self.rows = []

    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Append one table row per sample of the validation batch.

        Args:
            trainer: the running trainer (unused).
            pl_module: the module being validated (unused).
            outputs: the ``(metrics, predictions)`` pair returned by the
                module's ``validation_step``; ``metrics`` must hold "dice" and
                "dice_bin", ``predictions`` must hold "linear" and "binary".
            batch: the ``(images, ground_truth)`` pair fed to the step.
            batch_idx: index of the batch (unused).
            dataloader_idx: index of the dataloader (unused). Defaults to 0
                because some Lightning versions omit it when there is a
                single validation dataloader.
        """
        # unpacking
        images, ground_truth = batch
        metrics, predictions = outputs

        # Move everything to the CPU once, before iterating sample by sample.
        # NOTE(review): squeeze(1) assumes the binary mask has a singleton
        # dimension at index 1 (presumably the channel) - confirm.
        samples = zip(
            images.cpu(),
            ground_truth.cpu(),
            predictions["linear"].cpu(),
            predictions["binary"].cpu().squeeze(1).int().numpy(),
        )

        # NOTE: the "ID" value is the sample's index inside its batch, so it
        # restarts at 0 for every batch; the dice values are the ones found in
        # `metrics` and are repeated on every row of the batch.
        for i, (img, mask, pred, pred_bin) in enumerate(samples):
            self.rows.append(
                [
                    i,
                    wandb.Image(img),
                    wandb.Image(mask),
                    wandb.Image(
                        pred,
                        masks={
                            "predictions": {
                                "mask_data": pred_bin,
                                "class_labels": class_labels,
                            },
                        },
                    ),
                    metrics["dice"],
                    metrics["dice_bin"],
                ]
            )

    def on_validation_epoch_end(self, trainer, pl_module):
        """Send every row collected during the epoch to wandb as one table."""
        # log table
        wandb.log(
            {
                "val/predictions": wandb.Table(
                    columns=columns,
                    data=self.rows,
                )
            }
        )