feat: use LinearWarmupCosineAnnealingLR
Former-commit-id: a7292fe2b0898513fd0e913c2ae352a187f05b12 [formerly 432664f5bd6c3f8f54c221e4d7cc8853d08ea55b]
Former-commit-id: b777bd5053a005ec9b4da7db19a4dbe2a1ba41fd
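Note: LinearWarmupCosineAnnealingLR ships with Lightning Bolts (pl_bolts), not with pytorch_lightning itself. It ramps the learning rate linearly from warmup_start_lr up to the optimizer's base lr over warmup_epochs, then cosine-anneals it towards eta_min until max_epochs. A minimal sketch of the schedule shape, with a placeholder parameter and optimizer standing in for the real module:

    import torch
    from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

    # Placeholder parameter/optimizer, only to drive the scheduler.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3)

    # Linear warmup over the first 10 epochs, cosine annealing until
    # epoch 40 (the same values the commit hard-codes below).
    scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)

    for epoch in range(40):
        optimizer.step()   # optimizer steps before the scheduler, per PyTorch convention
        scheduler.step()   # advance the schedule by one epoch
        print(epoch, scheduler.get_last_lr())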
parent 291ea632bd
commit e94ea4147c
@@ -4,6 +4,7 @@ import pytorch_lightning as pl
 import torch
 import torchvision
 import wandb
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.mask_rcnn import (
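Note that the new import adds a dependency on the Lightning Bolts package (installed as lightning-bolts on PyPI), which does not come bundled with pytorch_lightning.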
@@ -43,8 +44,8 @@ class MRCNNModule(pl.LightningModule):
         # Network
         self.model = get_model_instance_segmentation(n_classes)
 
-        # onnx
-        self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True)
+        # onnx export
+        self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
 
     def forward(self, imgs):
         self.model.eval()
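The .half() cast matters because example_input_array is the sample Lightning feeds to torch.onnx.export when to_onnx() is called, so the module's weights must be half precision at export time to match. A hedged sketch of how the export might be invoked (the output filename is a placeholder, and the constructor signature is assumed from the n_classes usage above):

    # Hypothetical export call; "mrcnn.onnx" is a placeholder path.
    module = MRCNNModule(n_classes=2)  # assumed constructor signature
    module = module.half()             # weights must match the fp16 example input
    module.to_onnx("mrcnn.onnx", export_params=True)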
@@ -78,22 +79,9 @@ class MRCNNModule(pl.LightningModule):
             target["masks"] = target["masks"].squeeze(1).bool()
         self.metric.update(preds, targets)
 
-        # compute validation loss
-        self.model.train()
-        loss_dict = self.model(images, targets)
-        loss_dict = {f"valid/{key}": val for key, val in loss_dict.items()}
-        loss_dict["valid/loss"] = sum(loss_dict.values())
-        self.model.eval()
-
-        return loss_dict
+        return preds
 
     def validation_epoch_end(self, outputs):
-        # log validation loss
-        loss_dict = {
-            k: torch.stack([d[k] for d in outputs]).mean() for k in outputs[0].keys()
-        }  # TODO: update un dict object
-        self.log_dict(loss_dict)
-
         # log metrics
         metric_dict = self.metric.compute()
         metric_dict = {f"valid/{key}": val for key, val in metric_dict.items()}
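With the hand-rolled validation loss gone, validation quality now rests entirely on self.metric, a torchmetrics MeanAveragePrecision instance fed per-image dicts. A minimal sketch of the update format it expects, with dummy tensors (the 32x32 shapes and iou_type="segm" are assumptions, the latter based on the mask handling above):

    import torch
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    metric = MeanAveragePrecision(iou_type="segm")  # assumed: masks drive the metric

    preds = [{
        "masks": torch.zeros(1, 32, 32, dtype=torch.bool),   # one predicted instance
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([1]),
    }]
    targets = [{
        "masks": torch.zeros(1, 32, 32, dtype=torch.bool),   # one ground-truth instance
        "labels": torch.tensor([1]),
    }]

    metric.update(preds, targets)
    print(metric.compute()["map"])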
@@ -103,24 +91,20 @@ class MRCNNModule(pl.LightningModule):
         optimizer = torch.optim.Adam(
             self.parameters(),
             lr=wandb.config.LEARNING_RATE,
-            # momentum=wandb.config.MOMENTUM,
-            # weight_decay=wandb.config.WEIGHT_DECAY,
+            momentum=wandb.config.MOMENTUM,
+            weight_decay=wandb.config.WEIGHT_DECAY,
         )
 
-        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-        #     optimizer,
-        #     T_0=3,
-        #     T_mult=1,
-        #     lr=wandb.config.LEARNING_RATE_MIN,
-        #     verbose=True,
-        # )
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer,
+            warmup_epochs=10,
+            max_epochs=40,
+        )
 
-        # return {
-        #     "optimizer": optimizer,
-        #     "lr_scheduler": {
-        #         "scheduler": scheduler,
-        #         "monitor": "val_accuracy",
-        #     },
-        # }
-
-        return optimizer
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "monitor": "map",
+            },
+        }
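One caveat in the last hunk: torch.optim.Adam accepts no momentum argument (that belongs to SGD; Adam's moment coefficients are the betas tuple), so uncommenting momentum=wandb.config.MOMENTUM raises a TypeError as soon as configure_optimizers runs. A corrected body for the call, assuming MOMENTUM was meant to become Adam's beta1 (dropping the argument entirely is equally valid):

    # Inside configure_optimizers(); mapping MOMENTUM onto beta1 is an assumption.
    optimizer = torch.optim.Adam(
        self.parameters(),
        lr=wandb.config.LEARNING_RATE,
        betas=(wandb.config.MOMENTUM, 0.999),  # beta1 in place of SGD momentum
        weight_decay=wandb.config.WEIGHT_DECAY,
    )

Also worth noting: Lightning only consults the "monitor" entry for plateau-style schedulers, and the metrics above are logged under a "valid/" prefix, so monitoring plain "map" would likely not resolve even if it were used.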