mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: bunch of new callbacks
Former-commit-id: 40fa0eaabd63cc00becb79164255289b18faf26f [formerly 2e813e0fbd9b7316d500b1d3f694b680d1e4e949] Former-commit-id: 8b0372f2a4d3657e0728f52ae12c529df6985a07
This commit is contained in:
parent
4a319ac39a
commit
291ea632bd
21
src/train.py
21
src/train.py
|
@ -1,13 +1,21 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
|
||||||
import wandb
|
import wandb
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
|
from pytorch_lightning.callbacks import (
|
||||||
|
EarlyStopping,
|
||||||
|
LearningRateMonitor,
|
||||||
|
ModelCheckpoint,
|
||||||
|
ModelPruning,
|
||||||
|
QuantizationAwareTraining,
|
||||||
|
RichModelSummary,
|
||||||
|
RichProgressBar,
|
||||||
|
)
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
from data import Spheres
|
from data import Spheres
|
||||||
from mrcnn import MRCNNModule
|
from mrcnn import MRCNNModule
|
||||||
|
from utils.callback import TableLog
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# setup logging
|
# setup logging
|
||||||
|
@ -62,12 +70,17 @@ if __name__ == "__main__":
|
||||||
log_every_n_steps=5,
|
log_every_n_steps=5,
|
||||||
val_check_interval=50,
|
val_check_interval=50,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
ModelCheckpoint(monitor="valid/loss", mode="min"),
|
EarlyStopping(monitor="valid/map", mode="max", patience=10, min_delta=0.01),
|
||||||
|
ModelCheckpoint(monitor="valid/map", mode="max"),
|
||||||
|
# ModelPruning("l1_unstructured", amount=0.5),
|
||||||
|
LearningRateMonitor(log_momentum=True),
|
||||||
|
RichModelSummary(max_depth=2),
|
||||||
RichProgressBar(),
|
RichProgressBar(),
|
||||||
|
TableLog(),
|
||||||
],
|
],
|
||||||
# profiler="advanced",
|
# profiler="advanced",
|
||||||
num_sanity_val_steps=3,
|
num_sanity_val_steps=3,
|
||||||
devices=[1],
|
devices=[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
# actually train the model
|
# actually train the model
|
||||||
|
|
Loading…
Reference in a new issue