From e94ea4147c582661ec9c2fe4d3c3664ee6f73d65 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 7 Sep 2022 10:44:29 +0200 Subject: [PATCH] feat: use LinearWarmupCosineAnnealingLR Former-commit-id: a7292fe2b0898513fd0e913c2ae352a187f05b12 [formerly 432664f5bd6c3f8f54c221e4d7cc8853d08ea55b] Former-commit-id: b777bd5053a005ec9b4da7db19a4dbe2a1ba41fd --- src/mrcnn/module.py | 52 ++++++++++++++++----------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/src/mrcnn/module.py b/src/mrcnn/module.py index 6b93a88..03da906 100644 --- a/src/mrcnn/module.py +++ b/src/mrcnn/module.py @@ -4,6 +4,7 @@ import pytorch_lightning as pl import torch import torchvision import wandb +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import ( @@ -43,8 +44,8 @@ class MRCNNModule(pl.LightningModule): # Network self.model = get_model_instance_segmentation(n_classes) - # onnx - self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True) + # onnx export + self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half() def forward(self, imgs): self.model.eval() @@ -78,22 +79,9 @@ class MRCNNModule(pl.LightningModule): target["masks"] = target["masks"].squeeze(1).bool() self.metric.update(preds, targets) - # compute validation loss - self.model.train() - loss_dict = self.model(images, targets) - loss_dict = {f"valid/{key}": val for key, val in loss_dict.items()} - loss_dict["valid/loss"] = sum(loss_dict.values()) - self.model.eval() - - return loss_dict + return preds def validation_epoch_end(self, outputs): - # log validation loss - loss_dict = { - k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys() - } # TODO: update un dict object - self.log_dict(loss_dict) - # log metrics metric_dict = self.metric.compute() metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()} @@ -103,24 +91,20 @@ class MRCNNModule(pl.LightningModule): optimizer = torch.optim.Adam( self.parameters(), lr=wandb.config.LEARNING_RATE, - # momentum=wandb.config.MOMENTUM, - # weight_decay=wandb.config.WEIGHT_DECAY, + momentum=wandb.config.MOMENTUM, + weight_decay=wandb.config.WEIGHT_DECAY, ) - # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( - # optimizer, - # T_0=3, - # T_mult=1, - # lr=wandb.config.LEARNING_RATE_MIN, - # verbose=True, - # ) + scheduler = LinearWarmupCosineAnnealingLR( + optimizer, + warmup_epochs=10, + max_epochs=40, + ) - # return { - # "optimizer": optimizer, - # "lr_scheduler": { - # "scheduler": scheduler, - # "monitor": "val_accuracy", - # }, - # } - - return optimizer + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "map", + }, + }