feat: checkpoint wandb logging

feat: wandb config file

Former-commit-id: 45f56db86ca269b028cf76bf5315bc0eef8d2a21 [formerly e320b72e16eed02bdca05245e7c77914f0e288f9]
Former-commit-id: 6d91318784748308c73dc6a164653f04ae46cd2a
Laurent Fainsin 2022-07-11 15:34:05 +02:00
parent 4015dad491
commit 47c888cf6c
7 changed files with 119 additions and 91 deletions

.editorconfig

@@ -11,5 +11,5 @@ charset = utf-8
 trim_trailing_whitespace = true
 insert_final_newline = true

-[*.{json,toml}]
+[*.{json,toml,yaml,yml}]
 indent_size = 2

.vscode/launch.json vendored

@@ -18,7 +18,7 @@
                 "--model",
                 "good.onnx",
             ],
-            "justMyCode": true
+            "justMyCode": false
         }
     ]
 }

src/config-defaults.yaml Normal file

@@ -0,0 +1,41 @@
+DIR_TRAIN_IMG:
+  value: "/home/lilian/data_disk/lfainsin/train/"
+DIR_VALID_IMG:
+  value: "/home/lilian/data_disk/lfainsin/test_batched_fast/"
+DIR_SPHERE:
+  value: "/home/lilian/data_disk/lfainsin/spheres+real/"
+
+FEATURES:
+  value: [8, 16, 32, 64]
+N_CHANNELS:
+  value: 3
+N_CLASSES:
+  value: 1
+
+AMP:
+  value: True
+PIN_MEMORY:
+  value: True
+BENCHMARK:
+  value: True
+
+DEVICE:
+  value: gpu
+WORKERS:
+  value: 8
+
+IMG_SIZE:
+  value: 512
+SPHERES:
+  value: 5
+
+EPOCHS:
+  value: 10
+BATCH_SIZE:
+  value: 16
+LEARNING_RATE:
+  value: 1.0e-4
+WEIGHT_DECAY:
+  value: 1.0e-8
+MOMENTUM:
+  value: 0.9
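
Note: wandb automatically loads a file named config-defaults.yaml from the working directory when a run is initialized, which is why train.py below drops its explicit config dict. A minimal sketch of how the values come back out (key names from the file above):

import wandb

# wandb.init() picks up config-defaults.yaml on its own and merges
# the "value" entries into wandb.config
wandb.init(project="U-Net")

print(wandb.config.EPOCHS)         # 10
print(wandb.config.LEARNING_RATE)  # 0.0001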

src/data.py

@@ -36,19 +36,6 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )

-        # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
-        # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
-
-        # return DataLoader(
-        #     dataset,
-        #     shuffle=True,
-        #     batch_size=8,
-        #     prefetch_factor=8,
-        #     num_workers=wandb.config.WORKERS,
-        #     pin_memory=wandb.config.PIN_MEMORY,
-        # )
-
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
         dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

@@ -56,7 +43,7 @@ class Spheres(pl.LightningDataModule):
         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=1,
+            batch_size=8,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,

src/train.py

@@ -1,44 +1,21 @@
 import logging

 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger

 import wandb
 from data import Spheres
 from unet import UNetModule
-from utils import TableLog
-
-CONFIG = {
-    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
-    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
-    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
-    "FEATURES": [8, 16, 32, 64],
-    "N_CHANNELS": 3,
-    "N_CLASSES": 1,
-    "AMP": True,
-    "PIN_MEMORY": True,
-    "BENCHMARK": True,
-    "DEVICE": "gpu",
-    "WORKERS": 8,
-    "EPOCHS": 10,
-    "BATCH_SIZE": 16,
-    "LEARNING_RATE": 1e-4,
-    "WEIGHT_DECAY": 1e-8,
-    "MOMENTUM": 0.9,
-    "IMG_SIZE": 512,
-    "SPHERES": 5,
-}
+from utils import ArtifactLog, TableLog

 if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

-    # setup wandb
+    # setup wandb, config loaded from config-defaults.yaml
     logger = WandbLogger(
         project="U-Net",
-        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),

@@ -49,44 +26,38 @@ if __name__ == "__main__":
     # Create network
     model = UNetModule(
-        n_channels=CONFIG["N_CHANNELS"],
-        n_classes=CONFIG["N_CLASSES"],
-        batch_size=CONFIG["BATCH_SIZE"],
-        learning_rate=CONFIG["LEARNING_RATE"],
-        features=CONFIG["FEATURES"],
+        n_channels=wandb.config.N_CHANNELS,
+        n_classes=wandb.config.N_CLASSES,
+        batch_size=wandb.config.BATCH_SIZE,
+        learning_rate=wandb.config.LEARNING_RATE,
+        features=wandb.config.FEATURES,
     )

     # load checkpoint
-    state_dict = torch.load("checkpoints/synth.pth")
-    state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
-    model.load_state_dict(state_dict)
+    # state_dict = torch.load("checkpoints/synth.pth")
+    # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
+    # model.load_state_dict(state_dict)

     # log gradients and weights regularly
     logger.watch(model, log="all")

-    # create checkpoint callback
-    checkpoint_callback = ModelCheckpoint(
-        dirpath="checkpoints",
-        filename="model.ckpt",
-        monitor="val/dice",
-    )
-
     # Create the dataloaders
     datamodule = Spheres()

     # Create the trainer
     trainer = pl.Trainer(
-        max_epochs=CONFIG["EPOCHS"],
-        accelerator=CONFIG["DEVICE"],
-        benchmark=CONFIG["BENCHMARK"],
+        max_epochs=wandb.config.EPOCHS,
+        accelerator=wandb.config.DEVICE,
+        benchmark=wandb.config.BENCHMARK,
         # profiler="simple",
         # precision=16,
         logger=logger,
         log_every_n_steps=1,
-        val_check_interval=25,
-        callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
+        val_check_interval=100,
+        callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
     )

     # actually train the model
     trainer.fit(model=model, datamodule=datamodule)

     # stop wandb

src/utils/__init__.py

@@ -1,2 +1,2 @@
-from .callback import TableLog
+from .callback import ArtifactLog, TableLog
 from .paste import RandomPaste

src/utils/callback.py

@@ -1,5 +1,6 @@
 import numpy as np
+import torch
 from pytorch_lightning.callbacks import Callback
 from torch import tensor

 import wandb

@@ -22,35 +23,36 @@ class TableLog(Callback):
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         # unpacking
-        images, ground_truth = batch
-        metrics, predictions = outputs
-
-        for i, (img, mask, pred, pred_bin) in enumerate(
-            zip(
-                images.cpu(),
-                ground_truth.cpu(),
-                predictions["linear"].cpu(),
-                predictions["binary"].cpu().squeeze(1).int().numpy(),
-            )
-        ):
-            self.rows.append(
-                [
-                    i,
-                    wandb.Image(img),
-                    wandb.Image(mask),
-                    wandb.Image(
-                        pred,
-                        masks={
-                            "predictions": {
-                                "mask_data": pred_bin,
-                                "class_labels": class_labels,
-                            },
-                        },
-                    ),
-                    metrics["dice"],
-                    metrics["dice_bin"],
-                ]
-            )
+        if batch_idx == 0:
+            images, ground_truth = batch
+            metrics, predictions = outputs
+
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    ground_truth.cpu(),
+                    predictions["linear"].cpu(),
+                    predictions["binary"].cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                self.rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        metrics["dice"],
+                        metrics["dice_bin"],
+                    ]
+                )

     def on_validation_epoch_end(self, trainer, pl_module):
         # log table

@@ -62,3 +64,30 @@ class TableLog(Callback):
                 )
             }
         )
+
+
+class ArtifactLog(Callback):
+    def on_fit_start(self, trainer, pl_module):
+        # track the best (lowest) mean validation dice seen so far
+        self.best = 1
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        self.dices = []
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # unpacking
+        metrics, _ = outputs
+        self.dices.append(metrics["dice"].cpu())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        dice = np.mean(self.dices)
+        self.dices = []
+
+        if dice < self.best:
+            self.best = dice
+
+            # create checkpoint (save the LightningModule weights, not the callback's)
+            torch.save(pl_module.state_dict(), "checkpoints/model.pth")
+            # trainer.save_checkpoint("example.ckpt")  # TODO: change to .ckpt
+
+            # create and log artifact
+            artifact = wandb.Artifact("pth", type="model")
+            artifact.add_file("checkpoints/model.pth")
+            wandb.run.log_artifact(artifact)
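
Note: a later run can pull the weights saved by ArtifactLog back through the wandb artifact API. A minimal sketch, assuming the same project and the "pth" artifact name used above:

import torch
import wandb

run = wandb.init(project="U-Net")

# fetch the latest version of the model artifact logged above
artifact = run.use_artifact("pth:latest", type="model")
directory = artifact.download()

# plain state_dict saved with torch.save
state_dict = torch.load(f"{directory}/model.pth")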