feat: use LinearWarmupCosineAnnealingLR

Former-commit-id: a7292fe2b0898513fd0e913c2ae352a187f05b12 [formerly 432664f5bd6c3f8f54c221e4d7cc8853d08ea55b]
Former-commit-id: b777bd5053a005ec9b4da7db19a4dbe2a1ba41fd
This commit is contained in:
Laurent Fainsin 2022-09-07 10:44:29 +02:00
parent 291ea632bd
commit e94ea4147c

View file

@ -4,6 +4,7 @@ import pytorch_lightning as pl
import torch import torch
import torchvision import torchvision
import wandb import wandb
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import ( from torchvision.models.detection.mask_rcnn import (
@ -43,8 +44,8 @@ class MRCNNModule(pl.LightningModule):
# Network # Network
self.model = get_model_instance_segmentation(n_classes) self.model = get_model_instance_segmentation(n_classes)
# onnx # onnx export
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True) self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
def forward(self, imgs): def forward(self, imgs):
self.model.eval() self.model.eval()
@ -78,22 +79,9 @@ class MRCNNModule(pl.LightningModule):
target["masks"] = target["masks"].squeeze(1).bool() target["masks"] = target["masks"].squeeze(1).bool()
self.metric.update(preds, targets) self.metric.update(preds, targets)
# compute validation loss return preds
self.model.train()
loss_dict = self.model(images, targets)
loss_dict = {f"valid/{key}": val for key, val in loss_dict.items()}
loss_dict["valid/loss"] = sum(loss_dict.values())
self.model.eval()
return loss_dict
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs):
# log validation loss
loss_dict = {
k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
} # TODO: update un dict object
self.log_dict(loss_dict)
# log metrics # log metrics
metric_dict = self.metric.compute() metric_dict = self.metric.compute()
metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()} metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
@ -103,24 +91,20 @@ class MRCNNModule(pl.LightningModule):
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
self.parameters(), self.parameters(),
lr=wandb.config.LEARNING_RATE, lr=wandb.config.LEARNING_RATE,
# momentum=wandb.config.MOMENTUM, momentum=wandb.config.MOMENTUM,
# weight_decay=wandb.config.WEIGHT_DECAY, weight_decay=wandb.config.WEIGHT_DECAY,
) )
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( scheduler = LinearWarmupCosineAnnealingLR(
# optimizer, optimizer,
# T_0=3, warmup_epochs=10,
# T_mult=1, max_epochs=40,
# lr=wandb.config.LEARNING_RATE_MIN, )
# verbose=True,
# )
# return { return {
# "optimizer": optimizer, "optimizer": optimizer,
# "lr_scheduler": { "lr_scheduler": {
# "scheduler": scheduler, "scheduler": scheduler,
# "monitor": "val_accuracy", "monitor": "map",
# }, },
# } }
return optimizer