mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: ugly training image logging
Former-commit-id: 16a25008320f436069cff9f44bf013c1c2d0f890 [formerly 683afc2cb6322ce3f1d98797b947cca8c6af09a4] Former-commit-id: a5dae735e10107b514f028e84084ce7a303216ef
This commit is contained in:
parent
8611d8cd7a
commit
90978bfdc3
|
@ -4,7 +4,6 @@ import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning.callbacks import RichProgressBar
|
from pytorch_lightning.callbacks import RichProgressBar
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
|
@ -13,7 +12,7 @@ CONFIG = {
|
||||||
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
|
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
|
||||||
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
|
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
|
||||||
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
|
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
|
||||||
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/",
|
"DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres/",
|
||||||
"FEATURES": [8, 16, 32, 64],
|
"FEATURES": [8, 16, 32, 64],
|
||||||
"N_CHANNELS": 3,
|
"N_CHANNELS": 3,
|
||||||
"N_CLASSES": 1,
|
"N_CLASSES": 1,
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
""" Parts of the U-Net model """
|
"""Parts of the U-Net model."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
|
@ -130,6 +130,46 @@ class UNet(pl.LightningModule):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if batch_idx == 22000:
|
||||||
|
rows = []
|
||||||
|
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||||
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
|
zip(
|
||||||
|
images.cpu(),
|
||||||
|
masks_true.cpu(),
|
||||||
|
masks_pred.cpu(),
|
||||||
|
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
wandb.Image(img),
|
||||||
|
wandb.Image(mask),
|
||||||
|
wandb.Image(
|
||||||
|
pred,
|
||||||
|
masks={
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_bin,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
dice,
|
||||||
|
dice_bin,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# logging
|
||||||
|
try: # required by autofinding, logger replaced by dummy
|
||||||
|
self.logger.log_table(
|
||||||
|
key="train/predictions",
|
||||||
|
columns=columns,
|
||||||
|
data=rows,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
accuracy=accuracy,
|
accuracy=accuracy,
|
||||||
loss=dice,
|
loss=dice,
|
||||||
|
@ -155,7 +195,7 @@ class UNet(pl.LightningModule):
|
||||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
if batch_idx % 50 == 0:
|
if batch_idx % 50 == 0 or dice < 0.1:
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
zip(
|
zip(
|
||||||
images.cpu(),
|
images.cpu(),
|
||||||
|
|
Loading…
Reference in a new issue