feat: checkpoint wandb logging
feat: wandb config file

Former-commit-id: 45f56db86ca269b028cf76bf5315bc0eef8d2a21 [formerly e320b72e16eed02bdca05245e7c77914f0e288f9]
Former-commit-id: 6d91318784748308c73dc6a164653f04ae46cd2a
parent 4015dad491
commit 47c888cf6c

.editorconfig
@@ -11,5 +11,5 @@ charset = utf-8
 trim_trailing_whitespace = true
 insert_final_newline = true
 
-[*.{json,toml}]
+[*.{json,toml,yaml,yml}]
 indent_size = 2

.vscode/launch.json (vendored)
@@ -18,7 +18,7 @@
                 "--model",
                 "good.onnx",
             ],
-            "justMyCode": true
+            "justMyCode": false
         }
     ]
 }

src/config-defaults.yaml (new file)
@@ -0,0 +1,41 @@
+DIR_TRAIN_IMG:
+  value: "/home/lilian/data_disk/lfainsin/train/"
+DIR_VALID_IMG:
+  value: "/home/lilian/data_disk/lfainsin/test_batched_fast/"
+DIR_SPHERE:
+  value: "/home/lilian/data_disk/lfainsin/spheres+real/"
+
+FEATURES:
+  value: [8, 16, 32, 64]
+N_CHANNELS:
+  value: 3
+N_CLASSES:
+  value: 1
+
+AMP:
+  value: True
+PIN_MEMORY:
+  value: True
+BENCHMARK:
+  value: True
+DEVICE:
+  value: gpu
+WORKERS:
+  value: 8
+
+IMG_SIZE:
+  value: 512
+SPHERES:
+  value: 5
+
+EPOCHS:
+  value: 10
+BATCH_SIZE:
+  value: 16
+
+LEARNING_RATE:
+  value: 1e-4
+WEIGHT_DECAY:
+  value: 1e-8
+MOMENTUM:
+  value: 0.9
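
A minimal sketch of how these defaults are consumed (not taken from the commit, and assuming the script is launched from src/ so the file sits in the working directory): wandb reads config-defaults.yaml automatically when a run is initialised and exposes each key as wandb.config.<KEY>.

    import wandb

    run = wandb.init(project="U-Net")  # picks up config-defaults.yaml if present
    print(wandb.config.BATCH_SIZE)     # 16, from the file above
    print(wandb.config.FEATURES)       # [8, 16, 32, 64]
    run.finish()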

src/data.py
@@ -36,19 +36,6 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )
 
-        # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
-        # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
-
-        # return DataLoader(
-        #     dataset,
-        #     shuffle=True,
-        #     batch_size=8,
-        #     prefetch_factor=8,
-        #     num_workers=wandb.config.WORKERS,
-        #     pin_memory=wandb.config.PIN_MEMORY,
-        # )
-
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
         dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
@@ -56,7 +43,7 @@ class Spheres(pl.LightningDataModule):
         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=1,
+            batch_size=8,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
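
The Subset call in val_dataloader keeps roughly one validation image in a hundred: the stride is len(dataset) // 100 + 1, so a hypothetical directory of 5000 images (the number is illustrative, not from the repo) gives indices 0, 51, 102, ... and 99 kept samples. A self-contained sketch of that arithmetic:

    from torch.utils.data import Subset

    n = 5000                                 # hypothetical dataset length
    stride = n // 100 + 1                    # 51
    indices = list(range(0, n, stride))      # [0, 51, 102, ...]
    print(len(indices))                      # 99 samples kept
    # subset = Subset(full_dataset, indices) # full_dataset: any Dataset of length n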

src/train.py
@@ -1,44 +1,21 @@
 import logging
 
 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 import wandb
 from data import Spheres
 from unet import UNetModule
-from utils import TableLog
-
-CONFIG = {
-    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
-    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
-    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
-    "FEATURES": [8, 16, 32, 64],
-    "N_CHANNELS": 3,
-    "N_CLASSES": 1,
-    "AMP": True,
-    "PIN_MEMORY": True,
-    "BENCHMARK": True,
-    "DEVICE": "gpu",
-    "WORKERS": 8,
-    "EPOCHS": 10,
-    "BATCH_SIZE": 16,
-    "LEARNING_RATE": 1e-4,
-    "WEIGHT_DECAY": 1e-8,
-    "MOMENTUM": 0.9,
-    "IMG_SIZE": 512,
-    "SPHERES": 5,
-}
+from utils import ArtifactLog, TableLog
 
 if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
-    # setup wandb
+    # setup wandb, config loaded from config-default.yaml
     logger = WandbLogger(
         project="U-Net",
-        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),
@@ -49,44 +26,38 @@ if __name__ == "__main__":
 
     # Create network
     model = UNetModule(
-        n_channels=CONFIG["N_CHANNELS"],
-        n_classes=CONFIG["N_CLASSES"],
-        batch_size=CONFIG["BATCH_SIZE"],
-        learning_rate=CONFIG["LEARNING_RATE"],
-        features=CONFIG["FEATURES"],
+        n_channels=wandb.config.N_CHANNELS,
+        n_classes=wandb.config.N_CLASSES,
+        batch_size=wandb.config.BATCH_SIZE,
+        learning_rate=wandb.config.LEARNING_RATE,
+        features=wandb.config.FEATURES,
     )
 
     # load checkpoint
-    state_dict = torch.load("checkpoints/synth.pth")
-    state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
-    model.load_state_dict(state_dict)
+    # state_dict = torch.load("checkpoints/synth.pth")
+    # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
+    # model.load_state_dict(state_dict)
 
     # log gradients and weights regularly
     logger.watch(model, log="all")
 
-    # create checkpoint callback
-    checkpoint_callback = ModelCheckpoint(
-        dirpath="checkpoints",
-        filename="model.ckpt",
-        monitor="val/dice",
-    )
     # Create the dataloaders
     datamodule = Spheres()
 
     # Create the trainer
     trainer = pl.Trainer(
-        max_epochs=CONFIG["EPOCHS"],
-        accelerator=CONFIG["DEVICE"],
-        benchmark=CONFIG["BENCHMARK"],
+        max_epochs=wandb.config.EPOCHS,
+        accelerator=wandb.config.DEVICE,
+        benchmark=wandb.config.BENCHMARK,
         # profiler="simple",
         # precision=16,
         logger=logger,
         log_every_n_steps=1,
-        val_check_interval=25,
-        callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
+        val_check_interval=100,
+        callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
     )
 
     # actually train the model
     trainer.fit(model=model, datamodule=datamodule)
 
     # stop wandb
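
The commented-out warm-start block relies on a key-prefix remap: the remap suggests the weights in checkpoints/synth.pth were saved from the bare UNet, while UNetModule stores the network under a model attribute, so every key needs a "model." prefix before load_state_dict() will match. A short sketch of that idiom in isolation (paths and the attribute name come from the commented code above, not verified here):

    import torch

    state_dict = torch.load("checkpoints/synth.pth", map_location="cpu")
    state_dict = {f"model.{key}": value for key, value in state_dict.items()}
    # model.load_state_dict(state_dict)  # model: an instance of UNetModule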

src/utils/__init__.py
@@ -1,2 +1,2 @@
-from .callback import TableLog
+from .callback import ArtifactLog, TableLog
 from .paste import RandomPaste

src/utils/callback.py
@@ -1,5 +1,6 @@
+import numpy as np
 import torch
 from pytorch_lightning.callbacks import Callback
 from torch import tensor
 
 import wandb
@@ -22,35 +23,36 @@ class TableLog(Callback):
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        # unpacking
-        images, ground_truth = batch
-        metrics, predictions = outputs
-
-        for i, (img, mask, pred, pred_bin) in enumerate(
-            zip(
-                images.cpu(),
-                ground_truth.cpu(),
-                predictions["linear"].cpu(),
-                predictions["binary"].cpu().squeeze(1).int().numpy(),
-            )
-        ):
-            self.rows.append(
-                [
-                    i,
-                    wandb.Image(img),
-                    wandb.Image(mask),
-                    wandb.Image(
-                        pred,
-                        masks={
-                            "predictions": {
-                                "mask_data": pred_bin,
-                                "class_labels": class_labels,
-                            },
-                        },
-                    ),
-                    metrics["dice"],
-                    metrics["dice_bin"],
-                ]
-            )
+        if batch_idx == 0:
+            # unpacking
+            images, ground_truth = batch
+            metrics, predictions = outputs
+
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    ground_truth.cpu(),
+                    predictions["linear"].cpu(),
+                    predictions["binary"].cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                self.rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        metrics["dice"],
+                        metrics["dice_bin"],
+                    ]
+                )
 
     def on_validation_epoch_end(self, trainer, pl_module):
         # log table
@@ -62,3 +64,30 @@ class TableLog(Callback):
                 )
             }
         )
+
+
+class ArtifactLog(Callback):
+    def on_validation_epoch_start(self, trainer, pl_module):
+        self.dices = []
+        self.best = 1
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # unpacking
+        metrics, _ = outputs
+        self.dices.append(metrics["dice"].cpu())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        dice = np.mean(self.dices)
+        self.dices = []
+
+        if dice < self.best:
+            self.best = dice
+
+            # create checkpoint from the model's weights
+            torch.save(pl_module.state_dict(), "checkpoints/model.pth")
+            # trainer.save_checkpoint("example.ckpt")  # TODO: change to .ckpt
+
+            # create and log artifact
+            artifact = wandb.Artifact("pth", type="model")
+            artifact.add_file("checkpoints/model.pth")
+            wandb.run.log_artifact(artifact)
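
A quick sketch of retrieving the model artifact that ArtifactLog uploads, for example in a later evaluation run (the "latest" alias and the inference job type are assumptions, not taken from the commit; the artifact name "pth" and type "model" are):

    import wandb

    run = wandb.init(project="U-Net", job_type="inference")
    artifact = run.use_artifact("pth:latest", type="model")
    weights_dir = artifact.download()  # local directory containing model.pth
    print(weights_dir)
    run.finish()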