mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: bunch of new callbacks
Former-commit-id: 40fa0eaabd63cc00becb79164255289b18faf26f [formerly 2e813e0fbd9b7316d500b1d3f694b680d1e4e949] Former-commit-id: 8b0372f2a4d3657e0728f52ae12c529df6985a07
This commit is contained in:
parent
4a319ac39a
commit
291ea632bd
21
src/train.py
21
src/train.py
|
@ -1,13 +1,21 @@
|
|||
import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import wandb
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
ModelPruning,
|
||||
QuantizationAwareTraining,
|
||||
RichModelSummary,
|
||||
RichProgressBar,
|
||||
)
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
from data import Spheres
|
||||
from mrcnn import MRCNNModule
|
||||
from utils.callback import TableLog
|
||||
|
||||
if __name__ == "__main__":
|
||||
# setup logging
|
||||
|
@ -62,12 +70,17 @@ if __name__ == "__main__":
|
|||
log_every_n_steps=5,
|
||||
val_check_interval=50,
|
||||
callbacks=[
|
||||
ModelCheckpoint(monitor="valid/loss", mode="min"),
|
||||
EarlyStopping(monitor="valid/map", mode="max", patience=10, min_delta=0.01),
|
||||
ModelCheckpoint(monitor="valid/map", mode="max"),
|
||||
# ModelPruning("l1_unstructured", amount=0.5),
|
||||
LearningRateMonitor(log_momentum=True),
|
||||
RichModelSummary(max_depth=2),
|
||||
RichProgressBar(),
|
||||
TableLog(),
|
||||
],
|
||||
# profiler="advanced",
|
||||
num_sanity_val_steps=3,
|
||||
devices=[1],
|
||||
devices=[0],
|
||||
)
|
||||
|
||||
# actually train the model
|
||||
|
|
Loading…
Reference in a new issue