feat: use LinearWarmupCosineAnnealingLR
commit e94ea4147c
parent 291ea632bd
@@ -4,6 +4,7 @@ import pytorch_lightning as pl
 import torch
 import torchvision
 import wandb
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import (
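Note: LinearWarmupCosineAnnealingLR from pl_bolts ramps the learning rate linearly from a small start value up to the optimizer's base LR over warmup_epochs, then decays it along a cosine curve until max_epochs. A minimal standalone sketch of that behavior (the dummy model and hyper-parameters are illustrative, not from this repo):

    import torch
    from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

    model = torch.nn.Linear(8, 2)  # dummy model, stands in for the real network
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # linear warmup over the first 10 epochs, cosine decay until epoch 40
    scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)

    for epoch in range(40):
        optimizer.step()                  # training steps would happen here
        scheduler.step()                  # one scheduler step per epoch
        print(epoch, scheduler.get_last_lr())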
@@ -43,8 +44,8 @@ class MRCNNModule(pl.LightningModule):
         # Network
         self.model = get_model_instance_segmentation(n_classes)
 
-        # onnx
-        self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True)
+        # onnx export
+        self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
 
     def forward(self, imgs):
         self.model.eval()
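Note: example_input_array is the tensor Lightning traces with when exporting the module, and the new .half() cast matches a model running in 16-bit precision. A hedged sketch of the export call; the constructor argument and output path are made up:

    # Lightning falls back to `example_input_array` when no input_sample is given.
    module = MRCNNModule(n_classes=2)   # hypothetical constructor arguments
    module.to_onnx("mask_rcnn.onnx")    # illustrative output path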
@@ -78,22 +79,9 @@ class MRCNNModule(pl.LightningModule):
             target["masks"] = target["masks"].squeeze(1).bool()
         self.metric.update(preds, targets)
 
-        # compute validation loss
-        self.model.train()
-        loss_dict = self.model(images, targets)
-        loss_dict = {f"valid/{key}": val for key, val in loss_dict.items()}
-        loss_dict["valid/loss"] = sum(loss_dict.values())
-        self.model.eval()
-
-        return loss_dict
+        return preds
 
     def validation_epoch_end(self, outputs):
-        # log validation loss
-        loss_dict = {
-            k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
-        }  # TODO: update un dict object
-        self.log_dict(loss_dict)
-
         # log metrics
         metric_dict = self.metric.compute()
         metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
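Note: with the validation-loss branch removed, validation quality is tracked solely by torchmetrics' MeanAveragePrecision, updated per step and computed at epoch end. A self-contained sketch of that update/compute cycle; the boxes, scores, and labels below are invented:

    import torch
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    metric = MeanAveragePrecision()
    preds = [{                           # one image's predictions
        "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([1]),
    }]
    targets = [{                         # matching ground truth
        "boxes": torch.tensor([[12.0, 11.0, 49.0, 52.0]]),
        "labels": torch.tensor([1]),
    }]
    metric.update(preds, targets)
    print(metric.compute()["map"])       # also map_50, map_75, mar_100, ...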
@@ -103,24 +91,20 @@ class MRCNNModule(pl.LightningModule):
         optimizer = torch.optim.Adam(
             self.parameters(),
             lr=wandb.config.LEARNING_RATE,
-            # momentum=wandb.config.MOMENTUM,
-            # weight_decay=wandb.config.WEIGHT_DECAY,
+            momentum=wandb.config.MOMENTUM,
+            weight_decay=wandb.config.WEIGHT_DECAY,
         )
 
-        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-        #     optimizer,
-        #     T_0=3,
-        #     T_mult=1,
-        #     lr=wandb.config.LEARNING_RATE_MIN,
-        #     verbose=True,
-        # )
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer,
+            warmup_epochs=10,
+            max_epochs=40,
+        )
 
-        # return {
-        #     "optimizer": optimizer,
-        #     "lr_scheduler": {
-        #         "scheduler": scheduler,
-        #         "monitor": "val_accuracy",
-        #     },
-        # }
-
-        return optimizer
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": "map",
+            },
+        }
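One caveat: torch.optim.Adam takes no momentum keyword (momentum belongs to SGD), so the newly uncommented momentum=wandb.config.MOMENTUM argument will raise a TypeError when the optimizer is constructed. A minimal sketch of configure_optimizers inside MRCNNModule that keeps the new scheduler, assuming the file's existing torch, wandb, and pl_bolts imports:

    def configure_optimizers(self):
        # Adam accepts weight_decay but not momentum; drop the latter
        # (or switch to torch.optim.SGD if momentum is wanted).
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=wandb.config.LEARNING_RATE,
            weight_decay=wandb.config.WEIGHT_DECAY,
        )
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
        return {
            "optimizer": optimizer,
            # "monitor" is only consulted by ReduceLROnPlateau-style schedulers,
            # but Lightning tolerates it here.
            "lr_scheduler": {"scheduler": scheduler, "monitor": "map"},
        }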