From 291ea632bd93c39d7c9a78e2578973d573417ce0 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 7 Sep 2022 10:44:10 +0200 Subject: [PATCH] feat: bunch of new callbacks Former-commit-id: 40fa0eaabd63cc00becb79164255289b18faf26f [formerly 2e813e0fbd9b7316d500b1d3f694b680d1e4e949] Former-commit-id: 8b0372f2a4d3657e0728f52ae12c529df6985a07 --- src/train.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index d0ec56e..6d65d61 100644 --- a/src/train.py +++ b/src/train.py @@ -1,13 +1,21 @@ import logging import pytorch_lightning as pl -import torch import wandb -from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, + ModelPruning, + QuantizationAwareTraining, + RichModelSummary, + RichProgressBar, +) from pytorch_lightning.loggers import WandbLogger from data import Spheres from mrcnn import MRCNNModule +from utils.callback import TableLog if __name__ == "__main__": # setup logging @@ -62,12 +70,17 @@ if __name__ == "__main__": log_every_n_steps=5, val_check_interval=50, callbacks=[ - ModelCheckpoint(monitor="valid/loss", mode="min"), + EarlyStopping(monitor="valid/map", mode="max", patience=10, min_delta=0.01), + ModelCheckpoint(monitor="valid/map", mode="max"), + # ModelPruning("l1_unstructured", amount=0.5), + LearningRateMonitor(log_momentum=True), + RichModelSummary(max_depth=2), RichProgressBar(), + TableLog(), ], # profiler="advanced", num_sanity_val_steps=3, - devices=[1], + devices=[0], ) # actually train the model