feat: checkpoint wandb logging

feat: wandb config file

Former-commit-id: 45f56db86ca269b028cf76bf5315bc0eef8d2a21 [formerly e320b72e16eed02bdca05245e7c77914f0e288f9]
Former-commit-id: 6d91318784748308c73dc6a164653f04ae46cd2a
Laurent Fainsin 2022-07-11 15:34:05 +02:00
parent 4015dad491
commit 47c888cf6c
7 changed files with 119 additions and 91 deletions

.editorconfig

@@ -11,5 +11,5 @@ charset = utf-8
 trim_trailing_whitespace = true
 insert_final_newline = true

-[*.{json,toml}]
+[*.{json,toml,yaml,yml}]
 indent_size = 2

.vscode/launch.json (vendored)

@@ -18,7 +18,7 @@
         "--model",
         "good.onnx",
       ],
-      "justMyCode": true
+      "justMyCode": false
     }
   ]
 }

src/config-defaults.yaml (new file)

@@ -0,0 +1,41 @@
+DIR_TRAIN_IMG:
+  value: "/home/lilian/data_disk/lfainsin/train/"
+DIR_VALID_IMG:
+  value: "/home/lilian/data_disk/lfainsin/test_batched_fast/"
+DIR_SPHERE:
+  value: "/home/lilian/data_disk/lfainsin/spheres+real/"
+FEATURES:
+  value: [8, 16, 32, 64]  # list syntax; {8, 16, 32, 64} would parse as a YAML mapping
+N_CHANNELS:
+  value: 3  # no trailing comma: "3," would be read as a string
+N_CLASSES:
+  value: 1
+AMP:
+  value: True
+PIN_MEMORY:
+  value: True
+BENCHMARK:
+  value: True
+DEVICE:
+  value: gpu
+WORKERS:
+  value: 8
+IMG_SIZE:
+  value: 512
+SPHERES:
+  value: 5
+EPOCHS:
+  value: 10
+BATCH_SIZE:
+  value: 16
+LEARNING_RATE:
+  value: 1e-4
+WEIGHT_DECAY:
+  value: 1e-8
+MOMENTUM:
+  value: 0.9
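
Note: wandb automatically reads a file named config-defaults.yaml from the working directory when a run starts, which is why the hard-coded CONFIG dict disappears from src/train.py below. A minimal sketch of that behavior, assuming the script is run from src/ so the file is found:

    import wandb

    # no config= argument needed: config-defaults.yaml is picked up
    # automatically and merged into wandb.config
    run = wandb.init(project="U-Net")

    print(wandb.config.EPOCHS)    # 10
    print(wandb.config.FEATURES)  # [8, 16, 32, 64]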

src/data.py

@@ -36,19 +36,6 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )

-        # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
-        # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
-        # return DataLoader(
-        #     dataset,
-        #     shuffle=True,
-        #     batch_size=8,
-        #     prefetch_factor=8,
-        #     num_workers=wandb.config.WORKERS,
-        #     pin_memory=wandb.config.PIN_MEMORY,
-        # )
-
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
         dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

@@ -56,7 +43,7 @@ class Spheres(pl.LightningDataModule):
         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=1,
+            batch_size=8,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
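
Note: the Subset line in val_dataloader keeps roughly 100 evenly spaced samples regardless of dataset size, so with the new batch_size=8 a validation pass is about 13 batches. A quick sketch of the stride arithmetic, with a hypothetical dataset length:

    n = 4737                       # hypothetical len(dataset)
    step = n // 100 + 1            # 48
    indices = list(range(0, n, step))
    print(len(indices))            # 99, i.e. about 100 validation samples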

src/train.py

@@ -1,44 +1,21 @@
 import logging

 import pytorch_lightning as pl
-import torch
-from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger

 import wandb
 from data import Spheres
 from unet import UNetModule
-from utils import TableLog
+from utils import ArtifactLog, TableLog

-CONFIG = {
-    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
-    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
-    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
-    "FEATURES": [8, 16, 32, 64],
-    "N_CHANNELS": 3,
-    "N_CLASSES": 1,
-    "AMP": True,
-    "PIN_MEMORY": True,
-    "BENCHMARK": True,
-    "DEVICE": "gpu",
-    "WORKERS": 8,
-    "EPOCHS": 10,
-    "BATCH_SIZE": 16,
-    "LEARNING_RATE": 1e-4,
-    "WEIGHT_DECAY": 1e-8,
-    "MOMENTUM": 0.9,
-    "IMG_SIZE": 512,
-    "SPHERES": 5,
-}
-
 if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

-    # setup wandb
+    # setup wandb, config loaded from config-defaults.yaml
     logger = WandbLogger(
         project="U-Net",
-        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),

@@ -49,44 +26,38 @@ if __name__ == "__main__":
     # Create network
     model = UNetModule(
-        n_channels=CONFIG["N_CHANNELS"],
-        n_classes=CONFIG["N_CLASSES"],
-        batch_size=CONFIG["BATCH_SIZE"],
-        learning_rate=CONFIG["LEARNING_RATE"],
-        features=CONFIG["FEATURES"],
+        n_channels=wandb.config.N_CHANNELS,
+        n_classes=wandb.config.N_CLASSES,
+        batch_size=wandb.config.BATCH_SIZE,
+        learning_rate=wandb.config.LEARNING_RATE,
+        features=wandb.config.FEATURES,
     )

     # load checkpoint
-    state_dict = torch.load("checkpoints/synth.pth")
-    state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
-    model.load_state_dict(state_dict)
+    # state_dict = torch.load("checkpoints/synth.pth")
+    # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
+    # model.load_state_dict(state_dict)

     # log gradients and weights regularly
     logger.watch(model, log="all")

-    # create checkpoint callback
-    checkpoint_callback = ModelCheckpoint(
-        dirpath="checkpoints",
-        filename="model.ckpt",
-        monitor="val/dice",
-    )
-
     # Create the dataloaders
     datamodule = Spheres()

     # Create the trainer
     trainer = pl.Trainer(
-        max_epochs=CONFIG["EPOCHS"],
-        accelerator=CONFIG["DEVICE"],
-        benchmark=CONFIG["BENCHMARK"],
+        max_epochs=wandb.config.EPOCHS,
+        accelerator=wandb.config.DEVICE,
+        benchmark=wandb.config.BENCHMARK,
         # profiler="simple",
         # precision=16,
         logger=logger,
         log_every_n_steps=1,
-        val_check_interval=25,
-        callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
+        val_check_interval=100,
+        callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
     )

-    # actually train the model
     trainer.fit(model=model, datamodule=datamodule)

     # stop wandb
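
Note: the checkpoint-loading block above is commented out rather than deleted; for reference, re-enabled it reads as below (the path and the "model." key prefix come from the diff, map_location is an added assumption):

    import torch

    # raw UNet weights from an earlier (synthetic) training run
    state_dict = torch.load("checkpoints/synth.pth", map_location="cpu")
    # UNetModule stores the network in its `model` attribute, so the
    # state-dict keys need the matching prefix before loading
    state_dict = {f"model.{key}": value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)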

src/utils/__init__.py

@@ -1,2 +1,2 @@
-from .callback import TableLog
+from .callback import ArtifactLog, TableLog
 from .paste import RandomPaste

src/utils/callback.py

@@ -1,5 +1,6 @@
+import numpy as np
+import torch
 from pytorch_lightning.callbacks import Callback
-from torch import tensor

 import wandb

@@ -22,35 +23,36 @@ class TableLog(Callback):
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         # unpacking
-        images, ground_truth = batch
-        metrics, predictions = outputs
+        if batch_idx == 0:
+            images, ground_truth = batch
+            metrics, predictions = outputs

-        for i, (img, mask, pred, pred_bin) in enumerate(
-            zip(
-                images.cpu(),
-                ground_truth.cpu(),
-                predictions["linear"].cpu(),
-                predictions["binary"].cpu().squeeze(1).int().numpy(),
-            )
-        ):
-            self.rows.append(
-                [
-                    i,
-                    wandb.Image(img),
-                    wandb.Image(mask),
-                    wandb.Image(
-                        pred,
-                        masks={
-                            "predictions": {
-                                "mask_data": pred_bin,
-                                "class_labels": class_labels,
-                            },
-                        },
-                    ),
-                    metrics["dice"],
-                    metrics["dice_bin"],
-                ]
-            )
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    ground_truth.cpu(),
+                    predictions["linear"].cpu(),
+                    predictions["binary"].cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                self.rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        metrics["dice"],
+                        metrics["dice_bin"],
+                    ]
+                )

     def on_validation_epoch_end(self, trainer, pl_module):
         # log table

@@ -62,3 +64,30 @@ class TableLog(Callback):
                 )
             }
         )
+
+
+class ArtifactLog(Callback):
+    def on_validation_epoch_start(self, trainer, pl_module):
+        self.dices = []
+        self.best = 1
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # unpacking
+        metrics, _ = outputs
+        self.dices.append(metrics["dice"].cpu())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        dice = np.mean(self.dices)
+        self.dices = []
+
+        if dice < self.best:
+            self.best = dice
+
+            # create checkpoint, saving the model weights (not the callback state)
+            torch.save(pl_module.state_dict(), "checkpoints/model.pth")
+            # trainer.save_checkpoint("example.ckpt")  # TODO: change to .ckpt
+
+            # create and log artifact
+            artifact = wandb.Artifact("pth", type="model")
+            artifact.add_file("checkpoints/model.pth")
+            wandb.run.log_artifact(artifact)
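
Note: as committed, ArtifactLog resets self.best in on_validation_epoch_start, so the comparison restarts every epoch and an artifact is uploaded after almost every validation pass (dice is compared with <, i.e. treated here as a loss where lower is better). If tracking the best score across the whole run is the intent (an assumption), initializing once in setup() is enough; a sketch showing only the changed hooks:

    class ArtifactLog(Callback):
        def setup(self, trainer, pl_module, stage=None):
            # runs once per fit, so `best` survives across epochs
            self.dices = []
            self.best = 1

        def on_validation_epoch_start(self, trainer, pl_module):
            self.dices = []

A later run can then fetch the saved weights with run.use_artifact("pth:latest").download().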