refactor: moved the table logging into a callback
Former-commit-id: 37fa7b0da4556417ee1665d9f0375e71ce075958 [formerly 9314a3dfee09b085bb0d125d178c9f589532c6f5] Former-commit-id: 9c8bef3f5c6560f26512bc6586f1337b9d7985f3
This commit is contained in:
parent
edcc3c7bb2
commit
4015dad491
|
@ -36,6 +36,19 @@ class Spheres(pl.LightningDataModule):
|
|||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
)
|
||||
|
||||
# dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
|
||||
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
||||
|
||||
# return DataLoader(
|
||||
# dataset,
|
||||
# shuffle=True,
|
||||
# batch_size=8,
|
||||
# prefetch_factor=8,
|
||||
# num_workers=wandb.config.WORKERS,
|
||||
# pin_memory=wandb.config.PIN_MEMORY,
|
||||
# )
|
||||
|
||||
def val_dataloader(self):
|
||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
||||
|
|
22
src/train.py
22
src/train.py
|
@ -1,17 +1,19 @@
|
|||
import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
import wandb
|
||||
from data import Spheres
|
||||
from unet import UNetModule
|
||||
from utils import TableLog
|
||||
|
||||
CONFIG = {
|
||||
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
|
||||
"DIR_VALID_IMG": "//home/lilian/data_disk/lfainsin/test_split/",
|
||||
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real_split/",
|
||||
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
|
||||
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
|
||||
"FEATURES": [8, 16, 32, 64],
|
||||
"N_CHANNELS": 3,
|
||||
"N_CLASSES": 1,
|
||||
|
@ -19,9 +21,9 @@ CONFIG = {
|
|||
"PIN_MEMORY": True,
|
||||
"BENCHMARK": True,
|
||||
"DEVICE": "gpu",
|
||||
"WORKERS": 14,
|
||||
"EPOCHS": 1,
|
||||
"BATCH_SIZE": 16 * 3,
|
||||
"WORKERS": 8,
|
||||
"EPOCHS": 10,
|
||||
"BATCH_SIZE": 16,
|
||||
"LEARNING_RATE": 1e-4,
|
||||
"WEIGHT_DECAY": 1e-8,
|
||||
"MOMENTUM": 0.9,
|
||||
|
@ -54,12 +56,18 @@ if __name__ == "__main__":
|
|||
features=CONFIG["FEATURES"],
|
||||
)
|
||||
|
||||
# load checkpoint
|
||||
state_dict = torch.load("checkpoints/synth.pth")
|
||||
state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# log gradients and weights regularly
|
||||
logger.watch(model, log="all")
|
||||
|
||||
# create checkpoint callback
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath="checkpoints",
|
||||
filename="model.ckpt",
|
||||
monitor="val/dice",
|
||||
)
|
||||
|
||||
|
@ -75,8 +83,8 @@ if __name__ == "__main__":
|
|||
# precision=16,
|
||||
logger=logger,
|
||||
log_every_n_steps=1,
|
||||
val_check_interval=100,
|
||||
callbacks=RichProgressBar(),
|
||||
val_check_interval=25,
|
||||
callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
|
||||
)
|
||||
|
||||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
|
|
@ -52,105 +52,30 @@ class UNetModule(pl.LightningModule):
|
|||
}
|
||||
|
||||
# wrap tensors in a dictionary
|
||||
tensors = {
|
||||
"data": data,
|
||||
"ground_truth": ground_truth,
|
||||
"prediction": prediction,
|
||||
predictions = {
|
||||
"linear": prediction,
|
||||
"binary": binary,
|
||||
}
|
||||
|
||||
return metrics, tensors
|
||||
return metrics, predictions
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# compute metrics
|
||||
metrics, tensors = self.shared_step(batch)
|
||||
metrics, _ = self.shared_step(batch)
|
||||
|
||||
# log metrics
|
||||
self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
|
||||
|
||||
if batch_idx == 5000:
|
||||
rows = []
|
||||
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
zip( # TODO: use comprehension list to zip the dictionnary
|
||||
tensors["data"].cpu(),
|
||||
tensors["ground_truth"].cpu(),
|
||||
tensors["prediction"].cpu(),
|
||||
tensors["binary"]
|
||||
.cpu()
|
||||
.squeeze(1)
|
||||
.int()
|
||||
.numpy(), # TODO: check if .functions can be moved elsewhere
|
||||
)
|
||||
):
|
||||
rows.append(
|
||||
[
|
||||
i,
|
||||
wandb.Image(img),
|
||||
wandb.Image(mask),
|
||||
wandb.Image(
|
||||
pred,
|
||||
masks={
|
||||
"predictions": {
|
||||
"mask_data": pred_bin,
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
),
|
||||
metrics["dice"],
|
||||
metrics["dice_bin"],
|
||||
]
|
||||
)
|
||||
|
||||
# log table
|
||||
wandb.log(
|
||||
{
|
||||
"train/predictions": wandb.Table(
|
||||
columns=columns,
|
||||
data=rows,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return metrics["dice"]
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
metrics, tensors = self.shared_step(batch)
|
||||
# compute metrics
|
||||
metrics, predictions = self.shared_step(batch)
|
||||
|
||||
rows = []
|
||||
if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
zip( # TODO: use comprehension list to zip the dictionnary
|
||||
tensors["data"].cpu(),
|
||||
tensors["ground_truth"].cpu(),
|
||||
tensors["prediction"].cpu(),
|
||||
tensors["binary"]
|
||||
.cpu()
|
||||
.squeeze(1)
|
||||
.int()
|
||||
.numpy(), # TODO: check if .functions can be moved elsewhere
|
||||
)
|
||||
):
|
||||
rows.append(
|
||||
[
|
||||
i,
|
||||
wandb.Image(img),
|
||||
wandb.Image(mask),
|
||||
wandb.Image(
|
||||
pred,
|
||||
masks={
|
||||
"predictions": {
|
||||
"mask_data": pred_bin,
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
),
|
||||
metrics["dice"],
|
||||
metrics["dice_bin"],
|
||||
]
|
||||
)
|
||||
# log metrics
|
||||
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
|
||||
|
||||
return metrics, rows
|
||||
return metrics, predictions
|
||||
|
||||
def validation_epoch_end(self, validation_outputs):
|
||||
# unpacking
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
from .callback import TableLog
|
||||
from .paste import RandomPaste
|
||||
|
|
64
src/utils/callback.py
Normal file
64
src/utils/callback.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
from pytorch_lightning.callbacks import Callback
|
||||
from torch import tensor
|
||||
|
||||
import wandb
|
||||
|
||||
# Column headers of the W&B predictions table, in the order each row is filled.
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]

# Label attached to each mask value in the W&B mask overlay.
class_labels = {1: "sphere"}
|
||||
|
||||
|
||||
class TableLog(Callback):
    """Collect validation samples and log them to W&B as a single table.

    Rows are accumulated batch after batch during a validation run and
    pushed under the ``val/predictions`` key once the run ends.

    The callback expects ``validation_step`` to return a
    ``(metrics, predictions)`` pair where ``metrics`` provides the ``"dice"``
    and ``"dice_bin"`` entries and ``predictions`` provides the ``"linear"``
    and ``"binary"`` tensors.
    """

    def __init__(self):
        super().__init__()
        # Defined up front so the attribute exists even if a batch-end hook
        # is ever invoked before `on_validation_epoch_start`.
        self.rows = []

    def on_validation_epoch_start(self, trainer, pl_module):
        """Start every validation run with an empty table."""
        self.rows = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Append one table row per sample of the batch that was just validated.

        Args:
            trainer: the running trainer (unused).
            pl_module: the module being validated (unused).
            outputs: ``(metrics, predictions)`` as returned by ``validation_step``.
            batch: ``(images, ground_truth)`` pair produced by the dataloader.
            batch_idx: index of the batch (unused).
            dataloader_idx: index of the dataloader (unused). Defaults to 0 so
                the hook also works when the trainer omits the argument.
        """
        # unpacking
        images, ground_truth = batch
        metrics, predictions = outputs

        # Move everything to the CPU once, before iterating over the samples.
        linear = predictions["linear"].cpu()
        # presumably squeeze(1) drops a singleton channel axis so that each
        # mask is 2D, as expected by wandb mask overlays - TODO confirm
        binary = predictions["binary"].cpu().squeeze(1).int().numpy()

        # NOTE(review): `i` restarts at 0 for every batch, so the "ID" column
        # is the position inside the batch, not a unique row id - confirm
        # this is intended.
        for i, (img, mask, pred, pred_bin) in enumerate(zip(images.cpu(), ground_truth.cpu(), linear, binary)):
            overlay = {
                "predictions": {
                    "mask_data": pred_bin,
                    "class_labels": class_labels,
                },
            }
            self.rows.append(
                [
                    i,
                    wandb.Image(img),
                    wandb.Image(mask),
                    wandb.Image(pred, masks=overlay),
                    # the same batch-level metrics are repeated on every row of the batch
                    metrics["dice"],
                    metrics["dice_bin"],
                ]
            )

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log all rows gathered during the validation run as one W&B table."""
        # log table
        table = wandb.Table(columns=columns, data=self.rows)
        wandb.log({"val/predictions": table})
|
Loading…
Reference in a new issue