refactor: code splitting
Former-commit-id: 7b293e392cc7d4135ef8562faece6f491c623718 [formerly 381a418ceab2cb7f367f07b9f3ea4f3c6a41ecac] Former-commit-id: c2ddd57c4a3592c93640170a737506bb64b60864
This commit is contained in:
parent
ebb213c565
commit
5f46efa5a1
|
@ -1 +1 @@
|
||||||
0f3136c724eea42fdf1ee15e721ef33604e9a46d
|
ac8ff07f541ae6d7cba729b20e0d04654c6018c9
|
1
src/data/__init__.py
Normal file
1
src/data/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .dataloader import SyntheticSphere
|
50
src/data/dataloader.py
Normal file
50
src/data/dataloader.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import albumentations as A
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
from utils import RandomPaste
|
||||||
|
|
||||||
|
from .dataset import SphereDataset
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticSphere(pl.LightningDataModule):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
tf_train = A.Compose(
|
||||||
|
[
|
||||||
|
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||||
|
A.Flip(),
|
||||||
|
A.ColorJitter(),
|
||||||
|
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
||||||
|
A.GaussianBlur(),
|
||||||
|
A.ISONoise(),
|
||||||
|
A.ToFloat(max_value=255),
|
||||||
|
ToTensorV2(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||||
|
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||||
|
|
||||||
|
return DataLoader(
|
||||||
|
ds_train,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=wandb.config.BATCH_SIZE,
|
||||||
|
num_workers=wandb.config.WORKERS,
|
||||||
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
)
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||||
|
|
||||||
|
return DataLoader(
|
||||||
|
ds_valid,
|
||||||
|
shuffle=False,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=wandb.config.WORKERS,
|
||||||
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
)
|
14
src/train.py
14
src/train.py
|
@ -1,7 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
|
||||||
from pytorch_lightning.callbacks import RichProgressBar
|
from pytorch_lightning.callbacks import RichProgressBar
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
|
@ -57,13 +56,17 @@ if __name__ == "__main__":
|
||||||
# log gradients and weights regularly
|
# log gradients and weights regularly
|
||||||
logger.watch(net, log="all")
|
logger.watch(net, log="all")
|
||||||
|
|
||||||
|
# create checkpoint callback
|
||||||
|
checkpoint_callback = pl.ModelCheckpoint(
|
||||||
|
dirpath="checkpoints",
|
||||||
|
monitor="val/dice",
|
||||||
|
)
|
||||||
|
|
||||||
# Create the trainer
|
# Create the trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
max_epochs=CONFIG["EPOCHS"],
|
max_epochs=CONFIG["EPOCHS"],
|
||||||
accelerator=CONFIG["DEVICE"],
|
accelerator=CONFIG["DEVICE"],
|
||||||
# precision=16,
|
# precision=16,
|
||||||
# auto_scale_batch_size="binsearch",
|
|
||||||
# auto_lr_find=True,
|
|
||||||
benchmark=CONFIG["BENCHMARK"],
|
benchmark=CONFIG["BENCHMARK"],
|
||||||
val_check_interval=100,
|
val_check_interval=100,
|
||||||
callbacks=RichProgressBar(),
|
callbacks=RichProgressBar(),
|
||||||
|
@ -71,12 +74,7 @@ if __name__ == "__main__":
|
||||||
log_every_n_steps=1,
|
log_every_n_steps=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
trainer.tune(net)
|
|
||||||
trainer.fit(model=net)
|
trainer.fit(model=net)
|
||||||
except KeyboardInterrupt:
|
|
||||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# stop wandb
|
# stop wandb
|
||||||
wandb.run.finish()
|
wandb.run.finish()
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from .model import UNet
|
from .module import UNetModule
|
||||||
|
|
|
@ -1,37 +1,14 @@
|
||||||
""" Full assembly of the parts to form the complete network """
|
"""Full assembly of the parts to form the complete network."""
|
||||||
|
|
||||||
import itertools
|
import torch.nn as nn
|
||||||
|
|
||||||
import albumentations as A
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
from albumentations.pytorch import ToTensorV2
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
import wandb
|
|
||||||
from src.utils.dataset import SphereDataset
|
|
||||||
from utils.dice import dice_loss
|
|
||||||
from utils.paste import RandomPaste
|
|
||||||
|
|
||||||
from .blocks import *
|
from .blocks import *
|
||||||
|
|
||||||
class_labels = {
|
|
||||||
1: "sphere",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
class UNet(nn.Module):
|
||||||
class UNet(pl.LightningModule):
|
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||||
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
|
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
self.n_channels = n_channels
|
|
||||||
self.n_classes = n_classes
|
|
||||||
self.learning_rate = learning_rate
|
|
||||||
self.batch_size = batch_size
|
|
||||||
|
|
||||||
# log hyperparameters
|
|
||||||
self.save_hyperparameters()
|
|
||||||
|
|
||||||
# Network
|
# Network
|
||||||
self.inc = DoubleConv(n_channels, features[0])
|
self.inc = DoubleConv(n_channels, features[0])
|
||||||
|
|
||||||
|
@ -66,224 +43,3 @@ class UNet(pl.LightningModule):
|
||||||
x = self.outc(x)
|
x = self.outc(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def train_dataloader(self):
|
|
||||||
tf_train = A.Compose(
|
|
||||||
[
|
|
||||||
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
|
||||||
A.Flip(),
|
|
||||||
A.ColorJitter(),
|
|
||||||
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
|
||||||
A.GaussianBlur(),
|
|
||||||
A.ISONoise(),
|
|
||||||
A.ToFloat(max_value=255),
|
|
||||||
ToTensorV2(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
|
||||||
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
|
||||||
|
|
||||||
return DataLoader(
|
|
||||||
ds_train,
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=wandb.config.WORKERS,
|
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
|
||||||
)
|
|
||||||
|
|
||||||
def val_dataloader(self):
|
|
||||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
|
||||||
|
|
||||||
return DataLoader(
|
|
||||||
ds_valid,
|
|
||||||
shuffle=False,
|
|
||||||
batch_size=1,
|
|
||||||
num_workers=wandb.config.WORKERS,
|
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
|
||||||
)
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
|
||||||
# unpacking
|
|
||||||
images, masks_true = batch
|
|
||||||
masks_true = masks_true.unsqueeze(1)
|
|
||||||
|
|
||||||
# forward pass
|
|
||||||
masks_pred = self(images)
|
|
||||||
|
|
||||||
# compute metrics
|
|
||||||
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
|
|
||||||
dice = dice_loss(masks_pred, masks_true)
|
|
||||||
|
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
|
||||||
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
|
|
||||||
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
|
||||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
|
||||||
|
|
||||||
self.log_dict(
|
|
||||||
{
|
|
||||||
"train/accuracy": accuracy,
|
|
||||||
"train/dice": dice,
|
|
||||||
"train/dice_bin": dice_bin,
|
|
||||||
"train/bce": bce,
|
|
||||||
"train/mae": mae,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_idx == 22000:
|
|
||||||
rows = []
|
|
||||||
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
|
||||||
zip(
|
|
||||||
images.cpu(),
|
|
||||||
masks_true.cpu(),
|
|
||||||
masks_pred.cpu(),
|
|
||||||
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
|
||||||
)
|
|
||||||
):
|
|
||||||
rows.append(
|
|
||||||
[
|
|
||||||
i,
|
|
||||||
wandb.Image(img),
|
|
||||||
wandb.Image(mask),
|
|
||||||
wandb.Image(
|
|
||||||
pred,
|
|
||||||
masks={
|
|
||||||
"predictions": {
|
|
||||||
"mask_data": pred_bin,
|
|
||||||
"class_labels": class_labels,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
dice,
|
|
||||||
dice_bin,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# logging
|
|
||||||
try: # required by autofinding, logger replaced by dummy
|
|
||||||
self.logger.log_table(
|
|
||||||
key="train/predictions",
|
|
||||||
columns=columns,
|
|
||||||
data=rows,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
accuracy=accuracy,
|
|
||||||
loss=dice,
|
|
||||||
bce=bce,
|
|
||||||
mae=mae,
|
|
||||||
)
|
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
|
||||||
# unpacking
|
|
||||||
images, masks_true = batch
|
|
||||||
masks_true = masks_true.unsqueeze(1)
|
|
||||||
|
|
||||||
# forward pass
|
|
||||||
masks_pred = self(images)
|
|
||||||
|
|
||||||
# compute metrics
|
|
||||||
bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
|
|
||||||
dice = dice_loss(masks_pred, masks_true)
|
|
||||||
|
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
|
||||||
dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
|
|
||||||
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
|
||||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
if batch_idx % 50 == 0 or dice > 0.9:
|
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
|
||||||
zip(
|
|
||||||
images.cpu(),
|
|
||||||
masks_true.cpu(),
|
|
||||||
masks_pred.cpu(),
|
|
||||||
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
|
||||||
)
|
|
||||||
):
|
|
||||||
rows.append(
|
|
||||||
[
|
|
||||||
i,
|
|
||||||
wandb.Image(img),
|
|
||||||
wandb.Image(mask),
|
|
||||||
wandb.Image(
|
|
||||||
pred,
|
|
||||||
masks={
|
|
||||||
"predictions": {
|
|
||||||
"mask_data": pred_bin,
|
|
||||||
"class_labels": class_labels,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
dice,
|
|
||||||
dice_bin,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return dict(
|
|
||||||
accuracy=accuracy,
|
|
||||||
loss=dice,
|
|
||||||
dice_bin=dice_bin,
|
|
||||||
bce=bce,
|
|
||||||
mae=mae,
|
|
||||||
table_rows=rows,
|
|
||||||
)
|
|
||||||
|
|
||||||
def validation_epoch_end(self, validation_outputs):
|
|
||||||
# matrics unpacking
|
|
||||||
accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
|
|
||||||
dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
|
|
||||||
loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
|
|
||||||
bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
|
|
||||||
mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
|
|
||||||
|
|
||||||
# table unpacking
|
|
||||||
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
|
||||||
rowss = [d["table_rows"] for d in validation_outputs]
|
|
||||||
rows = list(itertools.chain.from_iterable(rowss))
|
|
||||||
|
|
||||||
# logging
|
|
||||||
try: # required by autofinding, logger replaced by dummy
|
|
||||||
self.logger.log_table(
|
|
||||||
key="val/predictions",
|
|
||||||
columns=columns,
|
|
||||||
data=rows,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.log_dict(
|
|
||||||
{
|
|
||||||
"val/accuracy": accuracy,
|
|
||||||
"val/dice": loss,
|
|
||||||
"val/dice_bin": dice_bin,
|
|
||||||
"val/bce": bce,
|
|
||||||
"val/mae": mae,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# export model to pth
|
|
||||||
torch.save(self.state_dict(), f"checkpoints/model.pth")
|
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
|
||||||
artifact.add_file("checkpoints/model.pth")
|
|
||||||
wandb.run.log_artifact(artifact)
|
|
||||||
|
|
||||||
# export model to onnx
|
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
|
||||||
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
|
||||||
artifact.add_file("checkpoints/model.onnx")
|
|
||||||
wandb.run.log_artifact(artifact)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = torch.optim.RMSprop(
|
|
||||||
self.parameters(),
|
|
||||||
lr=self.learning_rate,
|
|
||||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
|
||||||
momentum=wandb.config.MOMENTUM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer
|
|
||||||
|
|
193
src/unet/module.py
Normal file
193
src/unet/module.py
Normal file
|
@ -0,0 +1,193 @@
|
||||||
|
"""Pytorch lightning wrapper for model."""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
from unet.model import UNet
|
||||||
|
from utils.dice import dice_loss
|
||||||
|
|
||||||
|
from .blocks import *
|
||||||
|
|
||||||
|
class_labels = {
|
||||||
|
1: "sphere",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class UNetModule(pl.LightningModule):
|
||||||
|
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
|
||||||
|
super(UNetModule, self).__init__()
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
self.n_channels = n_channels
|
||||||
|
self.n_classes = n_classes
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
# log hyperparameters
|
||||||
|
self.save_hyperparameters()
|
||||||
|
|
||||||
|
# Network
|
||||||
|
self.model = UNet(n_channels, n_classes, features)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.model(x)
|
||||||
|
|
||||||
|
def shared_step(self, batch):
|
||||||
|
data, ground_truth = batch # unpacking
|
||||||
|
ground_truth = ground_truth.unsqueeze(1) # 1HW -> HW
|
||||||
|
|
||||||
|
# forward pass, compute masks
|
||||||
|
prediction = self.model(data)
|
||||||
|
binary = (torch.sigmoid(prediction) > 0.5).float() # TODO: check if float necessary
|
||||||
|
|
||||||
|
# compute metrics (in dictionnary)
|
||||||
|
metrics = {
|
||||||
|
"dice": dice_loss(prediction, ground_truth),
|
||||||
|
"dice_bin": dice_loss(binary, ground_truth, logits=False),
|
||||||
|
"bce": F.binary_cross_entropy_with_logits(prediction, ground_truth),
|
||||||
|
"mae": torch.nn.functional.l1_loss(binary, ground_truth),
|
||||||
|
"accuracy": (ground_truth == binary).float().mean(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# wrap tensors in dictionnary
|
||||||
|
tensors = {
|
||||||
|
"data": data,
|
||||||
|
"ground_truth": ground_truth,
|
||||||
|
"prediction": prediction,
|
||||||
|
"binary": binary,
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics, tensors
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# compute metrics
|
||||||
|
metrics, tensors = self.shared_step(batch)
|
||||||
|
|
||||||
|
# log metrics
|
||||||
|
self.log_dict(dict([(f"train/{key}", value) for key, value in metrics.items()]))
|
||||||
|
|
||||||
|
if batch_idx == 5000:
|
||||||
|
rows = []
|
||||||
|
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||||
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
|
zip( # TODO: use comprehension list to zip the dictionnary
|
||||||
|
tensors["images"].cpu(),
|
||||||
|
tensors["ground_truth"].cpu(),
|
||||||
|
tensors["prediction"].cpu(),
|
||||||
|
tensors["binary"]
|
||||||
|
.cpu()
|
||||||
|
.squeeze(1)
|
||||||
|
.int()
|
||||||
|
.numpy(), # TODO: check if .functions can be moved elsewhere
|
||||||
|
)
|
||||||
|
):
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
wandb.Image(img),
|
||||||
|
wandb.Image(mask),
|
||||||
|
wandb.Image(
|
||||||
|
pred,
|
||||||
|
masks={
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_bin,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
metrics["dice"],
|
||||||
|
metrics["dice_bin"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# log table
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"train/predictions": wandb.Table(
|
||||||
|
columns=columns,
|
||||||
|
data=rows,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics["dice"]
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
metrics, tensors = self.shared_step(batch)
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
|
||||||
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
|
zip( # TODO: use comprehension list to zip the dictionnary
|
||||||
|
tensors["images"].cpu(),
|
||||||
|
tensors["ground_truth"].cpu(),
|
||||||
|
tensors["prediction"].cpu(),
|
||||||
|
tensors["binary"]
|
||||||
|
.cpu()
|
||||||
|
.squeeze(1)
|
||||||
|
.int()
|
||||||
|
.numpy(), # TODO: check if .functions can be moved elsewhere
|
||||||
|
)
|
||||||
|
):
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
i,
|
||||||
|
wandb.Image(img),
|
||||||
|
wandb.Image(mask),
|
||||||
|
wandb.Image(
|
||||||
|
pred,
|
||||||
|
masks={
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_bin,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
metrics["dice"],
|
||||||
|
metrics["dice_bin"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def validation_epoch_end(self, validation_outputs):
|
||||||
|
# unpacking
|
||||||
|
metricss, rowss = validation_outputs
|
||||||
|
|
||||||
|
# metrics flattening
|
||||||
|
metrics = {
|
||||||
|
"dice": torch.stack([d["dice"] for d in metricss]).mean(),
|
||||||
|
"dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
|
||||||
|
"bce": torch.stack([d["bce"] for d in metricss]).mean(),
|
||||||
|
"mae": torch.stack([d["mae"] for d in metricss]).mean(),
|
||||||
|
"accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# log metrics
|
||||||
|
self.log_dict(dict([(f"val/{key}", value) for key, value in metrics.items()]))
|
||||||
|
|
||||||
|
# rows flattening
|
||||||
|
rows = list(itertools.chain.from_iterable(rowss))
|
||||||
|
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||||
|
|
||||||
|
# log table
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"val/predictions": wandb.Table(
|
||||||
|
columns=columns,
|
||||||
|
data=rows,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.RMSprop(
|
||||||
|
self.parameters(),
|
||||||
|
lr=self.learning_rate,
|
||||||
|
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||||
|
momentum=wandb.config.MOMENTUM,
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer
|
|
@ -0,0 +1 @@
|
||||||
|
from .paste import RandomPaste
|
Loading…
Reference in a new issue