diff --git a/src/train.py b/src/train.py index bb0c2e4..d111e38 100644 --- a/src/train.py +++ b/src/train.py @@ -4,7 +4,6 @@ import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger -from torch.utils.data import DataLoader import wandb from unet import UNet @@ -13,7 +12,7 @@ CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", - "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres_prod/", + "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres/", "FEATURES": [8, 16, 32, 64], "N_CHANNELS": 3, "N_CLASSES": 1, diff --git a/src/unet/blocks.py b/src/unet/blocks.py index 1f4a854..d125002 100644 --- a/src/unet/blocks.py +++ b/src/unet/blocks.py @@ -1,4 +1,4 @@ -""" Parts of the U-Net model """ +"""Parts of the U-Net model.""" import torch import torch.nn as nn diff --git a/src/unet/model.py b/src/unet/model.py index 8d5288f..11ddc65 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -130,6 +130,46 @@ class UNet(pl.LightningModule): }, ) + if batch_idx == 22000: + rows = [] + columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"] + for i, (img, mask, pred, pred_bin) in enumerate( + zip( + images.cpu(), + masks_true.cpu(), + masks_pred.cpu(), + masks_pred_bin.cpu().squeeze(1).int().numpy(), + ) + ): + rows.append( + [ + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + dice, + dice_bin, + ] + ) + + # logging + try: # required by autofinding, logger replaced by dummy + self.logger.log_table( + key="train/predictions", + columns=columns, + data=rows, + ) + except: + pass + return dict( accuracy=accuracy, loss=dice, @@ -155,7 +195,7 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() rows = [] - if batch_idx % 50 == 0: + if batch_idx % 50 == 0 or dice < 0.1: for i, (img, mask, pred, pred_bin) in enumerate( zip( images.cpu(),