Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-08 14:39:00 +00:00)
refactor: moved the table logging into a callback
Former-commit-id: 37fa7b0da4556417ee1665d9f0375e71ce075958 [formerly 9314a3dfee09b085bb0d125d178c9f589532c6f5] Former-commit-id: 9c8bef3f5c6560f26512bc6586f1337b9d7985f3
This commit is contained in:
parent edcc3c7bb2
commit 4015dad491
@@ -36,6 +36,19 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )
 
+    # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
+    # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
+    # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
+
+    # return DataLoader(
+    #     dataset,
+    #     shuffle=True,
+    #     batch_size=8,
+    #     prefetch_factor=8,
+    #     num_workers=wandb.config.WORKERS,
+    #     pin_memory=wandb.config.PIN_MEMORY,
+    # )
+
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
         dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
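For reference, the `Subset` line kept in context above thins the validation set to roughly one hundred evenly spaced samples: the step is `len(dataset) // 100 + 1`, so at most about 100 indices survive. A quick standalone check of that arithmetic, with a hypothetical dataset size of 1000 (the real size depends on DIR_VALID_IMG):

    # hypothetical dataset of 1000 items
    step = 1000 // 100 + 1                         # 11
    indices = list(range(0, 1000, step))
    print(len(indices), indices[:3], indices[-1])  # 91 [0, 11, 22] 990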
src/train.py (22 lines changed)
@@ -1,17 +1,19 @@
 import logging
 
 import pytorch_lightning as pl
+import torch
 from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 import wandb
 from data import Spheres
 from unet import UNetModule
+from utils import TableLog
 
 CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
-    "DIR_VALID_IMG": "//home/lilian/data_disk/lfainsin/test_split/",
-    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real_split/",
+    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
+    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
     "FEATURES": [8, 16, 32, 64],
     "N_CHANNELS": 3,
     "N_CLASSES": 1,
@@ -19,9 +21,9 @@ CONFIG = {
     "PIN_MEMORY": True,
     "BENCHMARK": True,
     "DEVICE": "gpu",
-    "WORKERS": 14,
-    "EPOCHS": 1,
-    "BATCH_SIZE": 16 * 3,
+    "WORKERS": 8,
+    "EPOCHS": 10,
+    "BATCH_SIZE": 16,
     "LEARNING_RATE": 1e-4,
     "WEIGHT_DECAY": 1e-8,
     "MOMENTUM": 0.9,
@@ -54,12 +56,18 @@ if __name__ == "__main__":
         features=CONFIG["FEATURES"],
     )
 
+    # load checkpoint
+    state_dict = torch.load("checkpoints/synth.pth")
+    state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
+    model.load_state_dict(state_dict)
+
     # log gradients and weights regularly
     logger.watch(model, log="all")
 
     # create checkpoint callback
     checkpoint_callback = ModelCheckpoint(
         dirpath="checkpoints",
+        filename="model.ckpt",
         monitor="val/dice",
     )
 
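The new checkpoint-loading block prefixes every key with "model." before calling load_state_dict. Presumably checkpoints/synth.pth holds the weights of the bare UNet, while UNetModule registers that network as a submodule, so its own state-dict keys carry the "model." prefix. A minimal, self-contained sketch of the same remapping with stand-in modules instead of the real UNet:

    import torch
    from torch import nn

    net = nn.Linear(4, 2)                      # stand-in for the bare UNet
    torch.save(net.state_dict(), "/tmp/synth.pth")

    class Wrapper(nn.Module):                  # stand-in for UNetModule
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(4, 2)

    state_dict = torch.load("/tmp/synth.pth")
    state_dict = {f"model.{key}": value for key, value in state_dict.items()}
    Wrapper().load_state_dict(state_dict)      # keys now match "model.weight", "model.bias"

The dict() over a list of tuples used in the diff and the dict comprehension above are equivalent; the comprehension is just the more idiomatic spelling.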
@@ -75,8 +83,8 @@ if __name__ == "__main__":
         # precision=16,
         logger=logger,
         log_every_n_steps=1,
-        val_check_interval=100,
-        callbacks=RichProgressBar(),
+        val_check_interval=25,
+        callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
     )
 
     trainer.fit(model=model, datamodule=datamodule)
@@ -52,105 +52,30 @@ class UNetModule(pl.LightningModule):
         }
 
         # wrap tensors in dictionnary
-        tensors = {
-            "data": data,
-            "ground_truth": ground_truth,
-            "prediction": prediction,
+        predictions = {
+            "linear": prediction,
             "binary": binary,
         }
 
-        return metrics, tensors
+        return metrics, predictions
 
     def training_step(self, batch, batch_idx):
         # compute metrics
-        metrics, tensors = self.shared_step(batch)
+        metrics, _ = self.shared_step(batch)
 
         # log metrics
         self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
 
-        if batch_idx == 5000:
-            rows = []
-            columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
-            for i, (img, mask, pred, pred_bin) in enumerate(
-                zip(  # TODO: use comprehension list to zip the dictionnary
-                    tensors["data"].cpu(),
-                    tensors["ground_truth"].cpu(),
-                    tensors["prediction"].cpu(),
-                    tensors["binary"]
-                    .cpu()
-                    .squeeze(1)
-                    .int()
-                    .numpy(),  # TODO: check if .functions can be moved elsewhere
-                )
-            ):
-                rows.append(
-                    [
-                        i,
-                        wandb.Image(img),
-                        wandb.Image(mask),
-                        wandb.Image(
-                            pred,
-                            masks={
-                                "predictions": {
-                                    "mask_data": pred_bin,
-                                    "class_labels": class_labels,
-                                },
-                            },
-                        ),
-                        metrics["dice"],
-                        metrics["dice_bin"],
-                    ]
-                )
-
-            # log table
-            wandb.log(
-                {
-                    "train/predictions": wandb.Table(
-                        columns=columns,
-                        data=rows,
-                    )
-                }
-            )
-
         return metrics["dice"]
 
     def validation_step(self, batch, batch_idx):
-        metrics, tensors = self.shared_step(batch)
-
-        rows = []
-        if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
-            for i, (img, mask, pred, pred_bin) in enumerate(
-                zip(  # TODO: use comprehension list to zip the dictionnary
-                    tensors["data"].cpu(),
-                    tensors["ground_truth"].cpu(),
-                    tensors["prediction"].cpu(),
-                    tensors["binary"]
-                    .cpu()
-                    .squeeze(1)
-                    .int()
-                    .numpy(),  # TODO: check if .functions can be moved elsewhere
-                )
-            ):
-                rows.append(
-                    [
-                        i,
-                        wandb.Image(img),
-                        wandb.Image(mask),
-                        wandb.Image(
-                            pred,
-                            masks={
-                                "predictions": {
-                                    "mask_data": pred_bin,
-                                    "class_labels": class_labels,
-                                },
-                            },
-                        ),
-                        metrics["dice"],
-                        metrics["dice_bin"],
-                    ]
-                )
-
-        return metrics, rows
+        # compute metrics
+        metrics, predictions = self.shared_step(batch)
+
+        # log metrics
+        self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
+
+        return metrics, predictions
 
     def validation_epoch_end(self, validation_outputs):
         # unpacking
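The key to the refactor is that, with the pytorch_lightning 1.x hooks this repository uses, whatever validation_step returns is handed to each callback as the outputs argument of on_validation_batch_end. That is what lets TableLog (added below) rebuild the wandb table from (metrics, predictions) without the LightningModule knowing anything about tables. A minimal sketch with a hypothetical callback name:

    from pytorch_lightning.callbacks import Callback

    class EchoOutputs(Callback):
        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            metrics, predictions = outputs  # the tuple returned by validation_step
            print(batch_idx, float(metrics["dice"]), sorted(predictions))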
@@ -1 +1,2 @@
+from .callback import TableLog
 from .paste import RandomPaste
src/utils/callback.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+from pytorch_lightning.callbacks import Callback
+from torch import tensor
+
+import wandb
+
+columns = [
+    "ID",
+    "image",
+    "ground truth",
+    "prediction",
+    "dice",
+    "dice_bin",
+]
+class_labels = {
+    1: "sphere",
+}
+
+
+class TableLog(Callback):
+    def on_validation_epoch_start(self, trainer, pl_module):
+        self.rows = []
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # unpacking
+        images, ground_truth = batch
+        metrics, predictions = outputs
+
+        for i, (img, mask, pred, pred_bin) in enumerate(
+            zip(
+                images.cpu(),
+                ground_truth.cpu(),
+                predictions["linear"].cpu(),
+                predictions["binary"].cpu().squeeze(1).int().numpy(),
+            )
+        ):
+            self.rows.append(
+                [
+                    i,
+                    wandb.Image(img),
+                    wandb.Image(mask),
+                    wandb.Image(
+                        pred,
+                        masks={
+                            "predictions": {
+                                "mask_data": pred_bin,
+                                "class_labels": class_labels,
+                            },
+                        },
+                    ),
+                    metrics["dice"],
+                    metrics["dice_bin"],
+                ]
+            )
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        # log table
+        wandb.log(
+            {
+                "val/predictions": wandb.Table(
+                    columns=columns,
+                    data=self.rows,
+                )
+            }
+        )
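A note on the .squeeze(1).int().numpy() chain applied to the binary prediction: wandb's segmentation overlay expects mask_data to be a 2D integer array of class ids, so the channel dimension is dropped and the tensor is converted before being wrapped. A standalone sketch of the same overlay with random data (offline mode and project name are placeholders):

    import numpy as np
    import wandb

    wandb.init(project="demo", mode="offline")
    image = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # fake RGB image
    mask = np.random.randint(0, 2, size=(64, 64))               # fake binary prediction, class ids 0/1
    wandb.log({
        "example": wandb.Image(
            image,
            masks={"predictions": {"mask_data": mask, "class_labels": {1: "sphere"}}},
        )
    })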