diff --git a/src/train.py b/src/train.py index d0ec56e..6d65d61 100644 --- a/src/train.py +++ b/src/train.py @@ -1,13 +1,21 @@ import logging import pytorch_lightning as pl -import torch import wandb -from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar +from pytorch_lightning.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, + ModelPruning, + QuantizationAwareTraining, + RichModelSummary, + RichProgressBar, +) from pytorch_lightning.loggers import WandbLogger from data import Spheres from mrcnn import MRCNNModule +from utils.callback import TableLog if __name__ == "__main__": # setup logging @@ -62,12 +70,17 @@ if __name__ == "__main__": log_every_n_steps=5, val_check_interval=50, callbacks=[ - ModelCheckpoint(monitor="valid/loss", mode="min"), + EarlyStopping(monitor="valid/map", mode="max", patience=10, min_delta=0.01), + ModelCheckpoint(monitor="valid/map", mode="max"), + # ModelPruning("l1_unstructured", amount=0.5), + LearningRateMonitor(log_momentum=True), + RichModelSummary(max_depth=2), RichProgressBar(), + TableLog(), ], # profiler="advanced", num_sanity_val_steps=3, - devices=[1], + devices=[0], ) # actually train the model