feat: use LinearWarmupCosineAnnealingLR

Former-commit-id: a7292fe2b0898513fd0e913c2ae352a187f05b12 [formerly 432664f5bd6c3f8f54c221e4d7cc8853d08ea55b]
Former-commit-id: b777bd5053a005ec9b4da7db19a4dbe2a1ba41fd
This commit is contained in:
Laurent Fainsin 2022-09-07 10:44:29 +02:00
parent 291ea632bd
commit e94ea4147c

View file

@ -4,6 +4,7 @@ import pytorch_lightning as pl
import torch
import torchvision
import wandb
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import (
@ -43,8 +44,8 @@ class MRCNNModule(pl.LightningModule):
# Network
self.model = get_model_instance_segmentation(n_classes)
# onnx
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True)
# onnx export
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
def forward(self, imgs):
self.model.eval()
@ -78,22 +79,9 @@ class MRCNNModule(pl.LightningModule):
target["masks"] = target["masks"].squeeze(1).bool()
self.metric.update(preds, targets)
# compute validation loss
self.model.train()
loss_dict = self.model(images, targets)
loss_dict = {f"valid/{key}": val for key, val in loss_dict.items()}
loss_dict["valid/loss"] = sum(loss_dict.values())
self.model.eval()
return loss_dict
return preds
def validation_epoch_end(self, outputs):
# log validation loss
loss_dict = {
k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
} # TODO: update un dict object
self.log_dict(loss_dict)
# log metrics
metric_dict = self.metric.compute()
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
@ -103,24 +91,20 @@ class MRCNNModule(pl.LightningModule):
optimizer = torch.optim.Adam(
self.parameters(),
lr=wandb.config.LEARNING_RATE,
# momentum=wandb.config.MOMENTUM,
# weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
weight_decay=wandb.config.WEIGHT_DECAY,
)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
# optimizer,
# T_0=3,
# T_mult=1,
# lr=wandb.config.LEARNING_RATE_MIN,
# verbose=True,
# )
scheduler = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=10,
max_epochs=40,
)
# return {
# "optimizer": optimizer,
# "lr_scheduler": {
# "scheduler": scheduler,
# "monitor": "val_accuracy",
# },
# }
return optimizer
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "map",
},
}