mirror of https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00

commit 47c888cf6c (parent 4015dad491)

    feat: checkpoint wandb logging

    feat: wandb config file
.editorconfig

@@ -11,5 +11,5 @@ charset = utf-8
 trim_trailing_whitespace = true
 insert_final_newline = true
 
-[*.{json,toml}]
+[*.{json,toml,yaml,yml}]
 indent_size = 2
.vscode/launch.json (vendored) | 2

@@ -18,7 +18,7 @@
                 "--model",
                 "good.onnx",
             ],
-            "justMyCode": true
+            "justMyCode": false
         }
     ]
 }
src/config-defaults.yaml (new file) | 41

@@ -0,0 +1,41 @@
+DIR_TRAIN_IMG:
+  value: "/home/lilian/data_disk/lfainsin/train/"
+DIR_VALID_IMG:
+  value: "/home/lilian/data_disk/lfainsin/test_batched_fast/"
+DIR_SPHERE:
+  value: "/home/lilian/data_disk/lfainsin/spheres+real/"
+
+FEATURES:
+  value: [8, 16, 32, 64]
+N_CHANNELS:
+  value: 3
+N_CLASSES:
+  value: 1
+
+AMP:
+  value: True
+PIN_MEMORY:
+  value: True
+BENCHMARK:
+  value: True
+DEVICE:
+  value: gpu
+WORKERS:
+  value: 8
+
+IMG_SIZE:
+  value: 512
+SPHERES:
+  value: 5
+
+EPOCHS:
+  value: 10
+BATCH_SIZE:
+  value: 16
+
+LEARNING_RATE:
+  value: 1.0e-4
+WEIGHT_DECAY:
+  value: 1.0e-8
+MOMENTUM:
+  value: 0.9
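The file follows wandb's config-defaults.yaml convention: each top-level key holds a mapping with a value field, and wandb.init() picks the file up automatically from the working directory, which is why train.py below can drop its config dict. A minimal sketch of how the values surface, assuming it is run from the directory containing the file (the project name mirrors src/train.py):

import wandb

# config-defaults.yaml is loaded automatically at init time;
# no config= argument is needed
run = wandb.init(project="U-Net")

# each top-level key becomes an attribute of wandb.config
print(wandb.config.EPOCHS)      # 10
print(wandb.config.BATCH_SIZE)  # 16
print(wandb.config.FEATURES)    # [8, 16, 32, 64]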
src/data.py

@@ -36,19 +36,6 @@ class Spheres(pl.LightningDataModule):
             pin_memory=wandb.config.PIN_MEMORY,
         )
 
-        # dataset = LabeledDataset(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
-        # dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
-
-        # return DataLoader(
-        #     dataset,
-        #     shuffle=True,
-        #     batch_size=8,
-        #     prefetch_factor=8,
-        #     num_workers=wandb.config.WORKERS,
-        #     pin_memory=wandb.config.PIN_MEMORY,
-        # )
-
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
         dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
@@ -56,7 +43,7 @@ class Spheres(pl.LightningDataModule):
         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=1,
+            batch_size=8,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
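The Subset call kept in val_dataloader subsamples the validation set: a stride of len(dataset) // 100 + 1 picks roughly 100 evenly spaced items, and the + 1 keeps the stride nonzero for datasets smaller than 100. A self-contained sketch of the same trick, with a dummy TensorDataset standing in for the repo's LabeledDataset:

import torch
from torch.utils.data import Subset, TensorDataset

# stand-in for LabeledDataset: 1000 dummy samples
dataset = TensorDataset(torch.arange(1000))

# stride = 1000 // 100 + 1 = 11, keeping indices 0, 11, ..., 990
indices = list(range(0, len(dataset), len(dataset) // 100 + 1))
subset = Subset(dataset, indices)

print(len(subset))  # 91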
src/train.py | 63

@@ -1,44 +1,21 @@
 import logging
 
 import pytorch_lightning as pl
-import torch
-from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
+from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 import wandb
 from data import Spheres
 from unet import UNetModule
-from utils import TableLog
+from utils import ArtifactLog, TableLog
 
-
-CONFIG = {
-    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
-    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/test_batched_fast/",
-    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real/",
-    "FEATURES": [8, 16, 32, 64],
-    "N_CHANNELS": 3,
-    "N_CLASSES": 1,
-    "AMP": True,
-    "PIN_MEMORY": True,
-    "BENCHMARK": True,
-    "DEVICE": "gpu",
-    "WORKERS": 8,
-    "EPOCHS": 10,
-    "BATCH_SIZE": 16,
-    "LEARNING_RATE": 1e-4,
-    "WEIGHT_DECAY": 1e-8,
-    "MOMENTUM": 0.9,
-}
-
-
 if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
-    # setup wandb
+    # setup wandb, config loaded from config-defaults.yaml
     logger = WandbLogger(
         project="U-Net",
-        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),
@@ -49,44 +26,38 @@ if __name__ == "__main__":
 
     # Create network
     model = UNetModule(
-        n_channels=CONFIG["N_CHANNELS"],
-        n_classes=CONFIG["N_CLASSES"],
-        batch_size=CONFIG["BATCH_SIZE"],
-        learning_rate=CONFIG["LEARNING_RATE"],
-        features=CONFIG["FEATURES"],
+        n_channels=wandb.config.N_CHANNELS,
+        n_classes=wandb.config.N_CLASSES,
+        batch_size=wandb.config.BATCH_SIZE,
+        learning_rate=wandb.config.LEARNING_RATE,
+        features=wandb.config.FEATURES,
     )
 
     # load checkpoint
-    state_dict = torch.load("checkpoints/synth.pth")
-    state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
-    model.load_state_dict(state_dict)
+    # state_dict = torch.load("checkpoints/synth.pth")
+    # state_dict = dict([(f"model.{key}", value) for key, value in state_dict.items()])
+    # model.load_state_dict(state_dict)
 
     # log gradients and weights regularly
     logger.watch(model, log="all")
 
-    # create checkpoint callback
-    checkpoint_callback = ModelCheckpoint(
-        dirpath="checkpoints",
-        filename="model.ckpt",
-        monitor="val/dice",
-    )
-
     # Create the dataloaders
     datamodule = Spheres()
 
     # Create the trainer
     trainer = pl.Trainer(
-        max_epochs=CONFIG["EPOCHS"],
-        accelerator=CONFIG["DEVICE"],
-        benchmark=CONFIG["BENCHMARK"],
+        max_epochs=wandb.config.EPOCHS,
+        accelerator=wandb.config.DEVICE,
+        benchmark=wandb.config.BENCHMARK,
         # profiler="simple",
         # precision=16,
         logger=logger,
         log_every_n_steps=1,
-        val_check_interval=25,
-        callbacks=[RichProgressBar(), checkpoint_callback, TableLog()],
+        val_check_interval=100,
+        callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
     )
 
+    # actually train the model
    trainer.fit(model=model, datamodule=datamodule)
 
     # stop wandb
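The now-commented checkpoint block records how pre-trained weights were loaded: the bare network's state dict keys must be prefixed with "model." because UNetModule evidently stores the network in an attribute of that name. A hedged sketch of the same remapping, with constructor arguments taken from the config above and illustrative key names (the checkpoint path is the one from the original code):

import torch

from unet import UNetModule

model = UNetModule(n_channels=3, n_classes=1, batch_size=16, learning_rate=1e-4, features=[8, 16, 32, 64])

# a checkpoint saved from the bare UNet has keys like "inc.conv.weight";
# the LightningModule wrapper expects them under its "model." namespace
state_dict = torch.load("checkpoints/synth.pth", map_location="cpu")
state_dict = {f"model.{key}": value for key, value in state_dict.items()}
model.load_state_dict(state_dict)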
src/utils/__init__.py

@@ -1,2 +1,2 @@
-from .callback import TableLog
+from .callback import ArtifactLog, TableLog
 from .paste import RandomPaste
src/utils/callback.py

@@ -1,5 +1,6 @@
+import numpy as np
+import torch
 from pytorch_lightning.callbacks import Callback
-from torch import tensor
 
 import wandb
 
@@ -22,6 +23,7 @@ class TableLog(Callback):
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         # unpacking
+        if batch_idx == 0:
             images, ground_truth = batch
             metrics, predictions = outputs
 
@@ -62,3 +64,30 @@ class TableLog(Callback):
             )
         }
     )
+
+
+class ArtifactLog(Callback):
+    def on_validation_epoch_start(self, trainer, pl_module):
+        self.dices = []
+        self.best = 1
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # unpacking
+        metrics, _ = outputs
+        self.dices.append(metrics["dice"].cpu())
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        dice = np.mean(self.dices)
+        self.dices = []
+
+        if dice < self.best:
+            self.best = dice
+
+            # create checkpoint
+            torch.save(pl_module.state_dict(), "checkpoints/model.pth")
+            # trainer.save_checkpoint("example.ckpt")  # TODO: change to .ckpt
+
+            # create and log artifact
+            artifact = wandb.Artifact("pth", type="model")
+            artifact.add_file("checkpoints/model.pth")
+            wandb.run.log_artifact(artifact)
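Checkpoints logged through ArtifactLog can be pulled back in a later run via the standard wandb artifact API. A minimal sketch, assuming the same "U-Net" project and the "pth" artifact name used above:

import torch

import wandb

run = wandb.init(project="U-Net")

# fetch the newest version of the model artifact and download it locally
artifact = run.use_artifact("pth:latest", type="model")
directory = artifact.download()  # folder containing model.pth

state_dict = torch.load(f"{directory}/model.pth", map_location="cpu")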