feat: bunch of new callbacks

Former-commit-id: 40fa0eaabd63cc00becb79164255289b18faf26f [formerly 2e813e0fbd9b7316d500b1d3f694b680d1e4e949]
Former-commit-id: 8b0372f2a4d3657e0728f52ae12c529df6985a07
This commit is contained in:
Laurent Fainsin 2022-09-07 10:44:10 +02:00
parent 4a319ac39a
commit 291ea632bd

View file

@ -1,13 +1,21 @@
import logging import logging
import pytorch_lightning as pl import pytorch_lightning as pl
import torch
import wandb import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
ModelPruning,
QuantizationAwareTraining,
RichModelSummary,
RichProgressBar,
)
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from data import Spheres from data import Spheres
from mrcnn import MRCNNModule from mrcnn import MRCNNModule
from utils.callback import TableLog
if __name__ == "__main__": if __name__ == "__main__":
# setup logging # setup logging
@ -62,12 +70,17 @@ if __name__ == "__main__":
log_every_n_steps=5, log_every_n_steps=5,
val_check_interval=50, val_check_interval=50,
callbacks=[ callbacks=[
ModelCheckpoint(monitor="valid/loss", mode="min"), EarlyStopping(monitor="valid/map", mode="max", patience=10, min_delta=0.01),
ModelCheckpoint(monitor="valid/map", mode="max"),
# ModelPruning("l1_unstructured", amount=0.5),
LearningRateMonitor(log_momentum=True),
RichModelSummary(max_depth=2),
RichProgressBar(), RichProgressBar(),
TableLog(),
], ],
# profiler="advanced", # profiler="advanced",
num_sanity_val_steps=3, num_sanity_val_steps=3,
devices=[1], devices=[0],
) )
# actually train the model # actually train the model