feat: ugly training image logging

Former-commit-id: 16a25008320f436069cff9f44bf013c1c2d0f890 [formerly 683afc2cb6322ce3f1d98797b947cca8c6af09a4]
Former-commit-id: a5dae735e10107b514f028e84084ce7a303216ef
This commit is contained in:
Laurent Fainsin 2022-07-08 09:54:45 +02:00
parent 8611d8cd7a
commit 90978bfdc3
3 changed files with 43 additions and 4 deletions

View file

@ -4,7 +4,6 @@ import pytorch_lightning as pl
import torch import torch
from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
import wandb import wandb
from unet import UNet from unet import UNet
@ -13,7 +12,7 @@ CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/", "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres/",
"FEATURES": [8, 16, 32, 64], "FEATURES": [8, 16, 32, 64],
"N_CHANNELS": 3, "N_CHANNELS": 3,
"N_CLASSES": 1, "N_CLASSES": 1,

View file

@ -1,4 +1,4 @@
""" Parts of the U-Net model """ """Parts of the U-Net model."""
import torch import torch
import torch.nn as nn import torch.nn as nn

View file

@ -130,6 +130,46 @@ class UNet(pl.LightningModule):
}, },
) )
if batch_idx == 22000:
rows = []
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
masks_true.cpu(),
masks_pred.cpu(),
masks_pred_bin.cpu().squeeze(1).int().numpy(),
)
):
rows.append(
[
i,
wandb.Image(img),
wandb.Image(mask),
wandb.Image(
pred,
masks={
"predictions": {
"mask_data": pred_bin,
"class_labels": class_labels,
},
},
),
dice,
dice_bin,
]
)
# logging
try: # required by autofinding, logger replaced by dummy
self.logger.log_table(
key="train/predictions",
columns=columns,
data=rows,
)
except:
pass
return dict( return dict(
accuracy=accuracy, accuracy=accuracy,
loss=dice, loss=dice,
@ -155,7 +195,7 @@ class UNet(pl.LightningModule):
accuracy = (masks_true == masks_pred_bin).float().mean() accuracy = (masks_true == masks_pred_bin).float().mean()
rows = [] rows = []
if batch_idx % 50 == 0: if batch_idx % 50 == 0 or dice < 0.1:
for i, (img, mask, pred, pred_bin) in enumerate( for i, (img, mask, pred, pred_bin) in enumerate(
zip( zip(
images.cpu(), images.cpu(),