feat(WIP): switching to pytorch lightning
Former-commit-id: 0038dbca182717af8fc4bd846fd5be0e9fa70a9a [formerly eb5eb0717f8511bf49de8393bbdc66e727b930ff] Former-commit-id: 540304228b146fe8e086bc4ccb770a13f84cbbcb
This commit is contained in:
parent
d785a5c6be
commit
982dfe99d7
275
src/train.py
275
src/train.py
|
@ -1,16 +1,16 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
from src.utils.dataset import SphereDataset
|
from src.utils.dataset import SphereDataset
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils.dice import dice_coeff
|
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
class_labels = {
|
class_labels = {
|
||||||
|
@ -22,7 +22,7 @@ if __name__ == "__main__":
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
wandb.init(
|
logger = WandbLogger(
|
||||||
project="U-Net",
|
project="U-Net",
|
||||||
config=dict(
|
config=dict(
|
||||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
|
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
|
||||||
|
@ -36,7 +36,7 @@ if __name__ == "__main__":
|
||||||
AMP=True,
|
AMP=True,
|
||||||
PIN_MEMORY=True,
|
PIN_MEMORY=True,
|
||||||
BENCHMARK=True,
|
BENCHMARK=True,
|
||||||
DEVICE="cuda",
|
DEVICE="gpu",
|
||||||
WORKERS=8,
|
WORKERS=8,
|
||||||
EPOCHS=5,
|
EPOCHS=5,
|
||||||
BATCH_SIZE=16,
|
BATCH_SIZE=16,
|
||||||
|
@ -51,18 +51,17 @@ if __name__ == "__main__":
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# create device
|
# seed random generators
|
||||||
device = torch.device(wandb.config.DEVICE)
|
pl.seed_everything(69420, workers=True)
|
||||||
|
|
||||||
# enable cudnn benchmarking
|
|
||||||
torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
|
|
||||||
|
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||||
|
|
||||||
|
# log the number of parameters of the model
|
||||||
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
|
|
||||||
# transfer network to device
|
# log gradients and weights regularly
|
||||||
net.to(device=device)
|
logger.watch(net, log="all")
|
||||||
|
|
||||||
# 1. Create transforms
|
# 1. Create transforms
|
||||||
tf_train = A.Compose(
|
tf_train = A.Compose(
|
||||||
|
@ -121,244 +120,38 @@ if __name__ == "__main__":
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
# 4. Create the trainer
|
||||||
optimizer = torch.optim.RMSprop(
|
trainer = pl.Trainer(
|
||||||
net.parameters(),
|
max_epochs=wandb.config.EPOCHS,
|
||||||
lr=wandb.config.LEARNING_RATE,
|
accelerator="gpu",
|
||||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
precision=16,
|
||||||
momentum=wandb.config.MOMENTUM,
|
auto_scale_batch_size="binsearch",
|
||||||
|
benchmark=wandb.config.BENCHMARK,
|
||||||
|
val_check_interval=100,
|
||||||
)
|
)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
|
||||||
criterion = torch.nn.BCEWithLogitsLoss()
|
|
||||||
|
|
||||||
# save model.onxx
|
|
||||||
dummy_input = torch.randn(
|
|
||||||
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
|
||||||
).to(device)
|
|
||||||
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
|
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
|
||||||
artifact.add_file("checkpoints/model-0.onnx")
|
|
||||||
wandb.run.log_artifact(artifact)
|
|
||||||
|
|
||||||
# log gradients and weights four time per epoch
|
|
||||||
wandb.watch(net, criterion, log_freq=100)
|
|
||||||
|
|
||||||
# print the config
|
# print the config
|
||||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||||
|
|
||||||
# wandb init log
|
# # wandb init log
|
||||||
wandb.log(
|
# wandb.log(
|
||||||
{
|
# {
|
||||||
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
# "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
||||||
},
|
# },
|
||||||
commit=False,
|
# commit=False,
|
||||||
)
|
# )
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(1, wandb.config.EPOCHS + 1):
|
trainer.fit(
|
||||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
model=net,
|
||||||
|
train_dataloaders=train_loader,
|
||||||
# Training round
|
val_dataloaders=val_loader,
|
||||||
for step, (images, true_masks) in enumerate(train_loader):
|
test_dataloaders=test_loader,
|
||||||
assert images.shape[1] == net.n_channels, (
|
accelerator=wandb.config.DEVICE,
|
||||||
f"Network has been defined with {net.n_channels} input channels, "
|
)
|
||||||
f"but loaded images have {images.shape[1]} channels. Please check that "
|
|
||||||
"the images are loaded correctly."
|
|
||||||
)
|
|
||||||
|
|
||||||
# transfer images to device
|
|
||||||
images = images.to(device=device)
|
|
||||||
true_masks = true_masks.unsqueeze(1).to(device=device)
|
|
||||||
|
|
||||||
# forward
|
|
||||||
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
|
|
||||||
pred_masks = net(images)
|
|
||||||
train_loss = criterion(pred_masks, true_masks)
|
|
||||||
|
|
||||||
# backward
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
|
||||||
grad_scaler.scale(train_loss).backward()
|
|
||||||
grad_scaler.step(optimizer)
|
|
||||||
grad_scaler.update()
|
|
||||||
|
|
||||||
# compute metrics
|
|
||||||
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
|
|
||||||
accuracy = (true_masks == pred_masks_bin).float().mean()
|
|
||||||
dice = dice_coeff(pred_masks_bin, true_masks)
|
|
||||||
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
|
||||||
|
|
||||||
# update tqdm progress bar
|
|
||||||
pbar.update(images.shape[0])
|
|
||||||
pbar.set_postfix(**{"loss": train_loss.item()})
|
|
||||||
|
|
||||||
# log metrics
|
|
||||||
wandb.log(
|
|
||||||
{
|
|
||||||
"epoch": epoch - 1 + step / len(train_loader),
|
|
||||||
"train/accuracy": accuracy,
|
|
||||||
"train/bce": train_loss,
|
|
||||||
"train/dice": dice,
|
|
||||||
"train/mae": mae,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if step and (step % 250 == 0 or step == len(train_loader)):
|
|
||||||
# Evaluation round
|
|
||||||
net.eval()
|
|
||||||
accuracy = 0
|
|
||||||
val_loss = 0
|
|
||||||
dice = 0
|
|
||||||
mae = 0
|
|
||||||
with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
|
|
||||||
for images, masks_true in val_loader:
|
|
||||||
|
|
||||||
# transfer images to device
|
|
||||||
images = images.to(device=device)
|
|
||||||
masks_true = masks_true.unsqueeze(1).to(device=device)
|
|
||||||
|
|
||||||
# forward
|
|
||||||
with torch.inference_mode():
|
|
||||||
masks_pred = net(images)
|
|
||||||
|
|
||||||
# compute metrics
|
|
||||||
val_loss += criterion(masks_pred, masks_true)
|
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
|
||||||
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
|
||||||
accuracy += (masks_true == masks_pred_bin).float().mean()
|
|
||||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
|
||||||
|
|
||||||
# update progress bar
|
|
||||||
pbar2.update(images.shape[0])
|
|
||||||
|
|
||||||
accuracy /= len(val_loader)
|
|
||||||
val_loss /= len(val_loader)
|
|
||||||
dice /= len(val_loader)
|
|
||||||
mae /= len(val_loader)
|
|
||||||
|
|
||||||
# save the last validation batch to table
|
|
||||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
|
||||||
zip(
|
|
||||||
images.cpu(),
|
|
||||||
masks_true.cpu(),
|
|
||||||
masks_pred.cpu(),
|
|
||||||
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
|
||||||
)
|
|
||||||
):
|
|
||||||
table.add_data(
|
|
||||||
i,
|
|
||||||
wandb.Image(img),
|
|
||||||
wandb.Image(mask),
|
|
||||||
wandb.Image(
|
|
||||||
pred,
|
|
||||||
masks={
|
|
||||||
"predictions": {
|
|
||||||
"mask_data": pred_bin,
|
|
||||||
"class_labels": class_labels,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# log validation metrics
|
|
||||||
wandb.log(
|
|
||||||
{
|
|
||||||
"val/predictions": table,
|
|
||||||
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
|
||||||
"val/accuracy": accuracy,
|
|
||||||
"val/bce": val_loss,
|
|
||||||
"val/dice": dice,
|
|
||||||
"val/mae": mae,
|
|
||||||
},
|
|
||||||
commit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# update hyperparameters
|
|
||||||
net.train()
|
|
||||||
scheduler.step(dice)
|
|
||||||
|
|
||||||
# export model to onnx format when validation ends
|
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
|
||||||
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
|
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
|
||||||
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
|
|
||||||
wandb.run.log_artifact(artifact)
|
|
||||||
|
|
||||||
# testing round
|
|
||||||
net.eval()
|
|
||||||
accuracy = 0
|
|
||||||
val_loss = 0
|
|
||||||
dice = 0
|
|
||||||
mae = 0
|
|
||||||
with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
|
|
||||||
for images, masks_true in test_loader:
|
|
||||||
|
|
||||||
# transfer images to device
|
|
||||||
images = images.to(device=device)
|
|
||||||
masks_true = masks_true.unsqueeze(1).to(device=device)
|
|
||||||
|
|
||||||
# forward
|
|
||||||
with torch.inference_mode():
|
|
||||||
masks_pred = net(images)
|
|
||||||
|
|
||||||
# compute metrics
|
|
||||||
val_loss += criterion(masks_pred, masks_true)
|
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
|
||||||
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
|
||||||
accuracy += (masks_true == masks_pred_bin).float().mean()
|
|
||||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
|
||||||
|
|
||||||
# update progress bar
|
|
||||||
pbar3.update(images.shape[0])
|
|
||||||
|
|
||||||
accuracy /= len(test_loader)
|
|
||||||
val_loss /= len(test_loader)
|
|
||||||
dice /= len(test_loader)
|
|
||||||
mae /= len(test_loader)
|
|
||||||
|
|
||||||
# save the last validation batch to table
|
|
||||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
|
||||||
zip(
|
|
||||||
images.cpu(),
|
|
||||||
masks_true.cpu(),
|
|
||||||
masks_pred.cpu(),
|
|
||||||
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
|
||||||
)
|
|
||||||
):
|
|
||||||
table.add_data(
|
|
||||||
i,
|
|
||||||
wandb.Image(img),
|
|
||||||
wandb.Image(mask),
|
|
||||||
wandb.Image(
|
|
||||||
pred,
|
|
||||||
masks={
|
|
||||||
"predictions": {
|
|
||||||
"mask_data": pred_bin,
|
|
||||||
"class_labels": class_labels,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# log validation metrics
|
|
||||||
wandb.log(
|
|
||||||
{
|
|
||||||
"test/predictions": table,
|
|
||||||
"test/accuracy": accuracy,
|
|
||||||
"test/bce": val_loss,
|
|
||||||
"test/dice": dice,
|
|
||||||
"test/mae": mae,
|
|
||||||
},
|
|
||||||
commit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# stop wandb
|
|
||||||
wandb.run.finish()
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# sapin de noel
|
# stop wandb
|
||||||
|
wandb.run.finish()
|
||||||
|
|
|
@ -1,9 +1,21 @@
|
||||||
""" Full assembly of the parts to form the complete network """
|
""" Full assembly of the parts to form the complete network """
|
||||||
|
|
||||||
|
from xmlrpc.server import list_public_methods
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
from utils.dice import dice_coeff
|
||||||
|
|
||||||
from .blocks import *
|
from .blocks import *
|
||||||
|
|
||||||
|
class_labels = {
|
||||||
|
1: "sphere",
|
||||||
|
}
|
||||||
|
|
||||||
class UNet(nn.Module):
|
|
||||||
|
class UNet(pl.LightningModule):
|
||||||
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
self.n_channels = n_channels
|
self.n_channels = n_channels
|
||||||
|
@ -26,7 +38,6 @@ class UNet(nn.Module):
|
||||||
self.outc = OutConv(features[0], n_classes)
|
self.outc = OutConv(features[0], n_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
skips = []
|
skips = []
|
||||||
|
|
||||||
x = self.inc(x)
|
x = self.inc(x)
|
||||||
|
@ -41,3 +52,158 @@ class UNet(nn.Module):
|
||||||
x = self.outc(x)
|
x = self.outc(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
|
||||||
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
|
|
||||||
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
|
zip(
|
||||||
|
images.cpu(),
|
||||||
|
masks_true.cpu(),
|
||||||
|
masks_pred.cpu(),
|
||||||
|
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
table.add_data(
|
||||||
|
i,
|
||||||
|
wandb.Image(img),
|
||||||
|
wandb.Image(mask),
|
||||||
|
wandb.Image(
|
||||||
|
pred,
|
||||||
|
masks={
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_bin,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
log_key: table,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
# unpacking
|
||||||
|
images, masks_true = batch
|
||||||
|
masks_true = masks_true.unsqueeze(1)
|
||||||
|
masks_pred = self(images)
|
||||||
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
|
|
||||||
|
# compute metrics
|
||||||
|
loss = F.cross_entropy(masks_pred, masks_true)
|
||||||
|
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
|
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||||
|
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||||
|
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"train/accuracy": accuracy,
|
||||||
|
"train/bce": loss,
|
||||||
|
"train/dice": dice,
|
||||||
|
"train/mae": mae,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss, dice, accuracy, mae
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
# unpacking
|
||||||
|
images, masks_true = batch
|
||||||
|
masks_true = masks_true.unsqueeze(1)
|
||||||
|
masks_pred = self(images)
|
||||||
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
|
|
||||||
|
# compute metrics
|
||||||
|
loss = F.cross_entropy(masks_pred, masks_true)
|
||||||
|
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
|
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||||
|
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||||
|
|
||||||
|
if batch_idx == 0:
|
||||||
|
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
|
||||||
|
|
||||||
|
return loss, dice, accuracy, mae
|
||||||
|
|
||||||
|
def validation_step_end(self, validation_outputs):
|
||||||
|
# unpacking
|
||||||
|
loss, dice, accuracy, mae = validation_outputs
|
||||||
|
optimizer = self.optimizers[0]
|
||||||
|
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||||
|
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"train/learning_rate": learning_rate,
|
||||||
|
"val/accuracy": accuracy,
|
||||||
|
"val/bce": loss,
|
||||||
|
"val/dice": dice,
|
||||||
|
"val/mae": mae,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# export model to onnx
|
||||||
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
||||||
|
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
||||||
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
|
artifact.add_file(f"checkpoints/model.onnx")
|
||||||
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
# unpacking
|
||||||
|
images, masks_true = batch
|
||||||
|
masks_true = masks_true.unsqueeze(1)
|
||||||
|
masks_pred = self(images)
|
||||||
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
|
|
||||||
|
# compute metrics
|
||||||
|
loss = F.cross_entropy(masks_pred, masks_true)
|
||||||
|
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
|
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||||
|
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||||
|
|
||||||
|
if batch_idx == 0:
|
||||||
|
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
|
||||||
|
|
||||||
|
return loss, dice, accuracy, mae
|
||||||
|
|
||||||
|
def test_step_end(self, test_outputs):
|
||||||
|
# unpacking
|
||||||
|
list_loss, list_dice, list_accuracy, list_mae = test_outputs
|
||||||
|
|
||||||
|
# averaging
|
||||||
|
loss = np.mean(list_loss)
|
||||||
|
dice = np.mean(list_dice)
|
||||||
|
accuracy = np.mean(list_accuracy)
|
||||||
|
mae = np.mean(list_mae)
|
||||||
|
|
||||||
|
# get learning rate
|
||||||
|
optimizer = self.optimizers[0]
|
||||||
|
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||||
|
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"train/learning_rate": learning_rate,
|
||||||
|
"val/accuracy": accuracy,
|
||||||
|
"val/bce": loss,
|
||||||
|
"val/dice": dice,
|
||||||
|
"val/mae": mae,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.RMSprop(
|
||||||
|
self.parameters(),
|
||||||
|
lr=wandb.config.LEARNING_RATE,
|
||||||
|
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||||
|
momentum=wandb.config.MOMENTUM,
|
||||||
|
)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer,
|
||||||
|
"max",
|
||||||
|
patience=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer, scheduler
|
||||||
|
|
Loading…
Reference in a new issue