mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-09-19 19:45:28 +00:00
feat: basic validation
Former-commit-id: 9d6ac1da57e613d3fa4db7ba68560509611c2b79 [formerly f588a5a0781e7375baa6ce5a4723a7b0d557316a] Former-commit-id: 7415ee71e570da13f9de361131156ba0d2131d8a
This commit is contained in:
parent
04ddf75dd8
commit
562ef110af
1911
poetry.lock
generated
1911
poetry.lock
generated
File diff suppressed because it is too large
Load diff
2
poetry.toml
Normal file
2
poetry.toml
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
[virtualenvs]
|
||||||
|
in-project = true
|
|
@ -11,13 +11,14 @@ numpy = "^1.23.0"
|
||||||
onnx = "^1.12.0"
|
onnx = "^1.12.0"
|
||||||
onnxruntime = "^1.11.1"
|
onnxruntime = "^1.11.1"
|
||||||
python = ">=3.8,<3.11"
|
python = ">=3.8,<3.11"
|
||||||
pytorch-lightning = "^1.6.4"
|
pytorch-lightning = "^1.7.3"
|
||||||
rich = "^12.4.4"
|
rich = "^12.4.4"
|
||||||
scipy = "^1.8.1"
|
scipy = "^1.8.1"
|
||||||
torch = "1.11.0"
|
torch = "^1.12.1"
|
||||||
torchvision = "^0.12.0"
|
torchvision = "^0.13.1"
|
||||||
tqdm = "^4.64.0"
|
tqdm = "^4.64.0"
|
||||||
wandb = "^0.12.19"
|
wandb = "^0.12.19"
|
||||||
|
pycocotools = "^2.0.4"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^22.3.0"
|
black = "^22.3.0"
|
||||||
|
|
|
@ -36,8 +36,8 @@ class Spheres(pl.LightningDataModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
|
dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
|
||||||
dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
|
# dataset = Subset(dataset, list(range(len(dataset)))) # somehow this sometimes allows to better utilize the gpu
|
||||||
# dataset = Subset(dataset, list(range(20))) # somehow this allows to better utilize the gpu
|
# dataset = Subset(dataset, list(range(20)))
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -50,15 +50,36 @@ class Spheres(pl.LightningDataModule):
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
# def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
# dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
transforms = A.Compose(
|
||||||
# dataset = Subset(dataset, list(range(len(dataset)))) # somehow this allows to better utilize the gpu
|
[
|
||||||
|
A.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225],
|
||||||
|
max_pixel_value=255,
|
||||||
|
), # [0, 255] -> [0.0, 1.0] normalized
|
||||||
|
# A.ToFloat(max_value=255),
|
||||||
|
ToTensorV2(), # HWC -> CHW
|
||||||
|
],
|
||||||
|
bbox_params=A.BboxParams(
|
||||||
|
format="pascal_voc",
|
||||||
|
min_area=0.0,
|
||||||
|
min_visibility=0.0,
|
||||||
|
label_fields=["labels"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# return DataLoader(
|
dataset = RealDataset(root="/dev/shm/TEST_tmp_mrcnn/", transforms=transforms)
|
||||||
# dataset,
|
# dataset = Subset(dataset, list(range(len(dataset)))) # somehow this sometimes allows to better utilize the gpu
|
||||||
# shuffle=False,
|
dataset = Subset(dataset, list(range(10)))
|
||||||
# prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
|
||||||
# batch_size=wandb.config.VAL_BATCH_SIZE,
|
return DataLoader(
|
||||||
# num_workers=wandb.config.WORKERS,
|
dataset,
|
||||||
# pin_memory=wandb.config.PIN_MEMORY,
|
shuffle=False,
|
||||||
# )
|
persistent_workers=True,
|
||||||
|
prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||||
|
batch_size=wandb.config.VALID_BATCH_SIZE,
|
||||||
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
num_workers=wandb.config.WORKERS,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
|
@ -58,6 +58,7 @@ class RealDataset(Dataset):
|
||||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
|
bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
|
||||||
labels = torch.ones((num_objs,), dtype=torch.int64) # suppose there is only one class (id=1)
|
labels = torch.ones((num_objs,), dtype=torch.int64) # suppose there is only one class (id=1)
|
||||||
masks = [mask for mask in masks] # albumentations wants list of masks
|
masks = [mask for mask in masks] # albumentations wants list of masks
|
||||||
|
# TODO: use masks = list(np.asarray(target["masks"])))
|
||||||
|
|
||||||
if self.transforms is not None:
|
if self.transforms is not None:
|
||||||
# arrange transform data
|
# arrange transform data
|
||||||
|
|
|
@ -4,14 +4,19 @@ import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
from torchvision.models.detection.mask_rcnn import (
|
||||||
|
MaskRCNN_ResNet50_FPN_Weights,
|
||||||
|
MaskRCNNPredictor,
|
||||||
|
)
|
||||||
|
|
||||||
import wandb
|
import wandb
|
||||||
|
from utils.coco_eval import CocoEvaluator
|
||||||
|
from utils.coco_utils import get_coco_api_from_dataset, get_iou_types
|
||||||
|
|
||||||
|
|
||||||
def get_model_instance_segmentation(num_classes):
|
def get_model_instance_segmentation(num_classes):
|
||||||
# load an instance segmentation model pre-trained on COCO
|
# load an instance segmentation model pre-trained on COCO
|
||||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) # TODO: tester v2
|
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
|
||||||
|
|
||||||
# get number of input features for the classifier
|
# get number of input features for the classifier
|
||||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
|
@ -41,87 +46,83 @@ class MRCNNModule(pl.LightningModule):
|
||||||
# Network
|
# Network
|
||||||
self.model = get_model_instance_segmentation(n_classes)
|
self.model = get_model_instance_segmentation(n_classes)
|
||||||
|
|
||||||
# def forward(self, imgs):
|
# pycoco evaluator
|
||||||
# # Torchvision FasterRCNN returns the loss during training
|
self.coco = None
|
||||||
# # and the boxes during eval
|
self.iou_types = get_iou_types(self.model)
|
||||||
# self.model.eval()
|
self.coco_evaluator = None
|
||||||
# return self.model(imgs)
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
# unpack batch
|
# unpack batch
|
||||||
images, targets = batch
|
images, targets = batch
|
||||||
|
|
||||||
# enable train mode
|
# compute loss
|
||||||
# self.model.train()
|
|
||||||
|
|
||||||
# fasterrcnn takes both images and targets for training
|
|
||||||
loss_dict = self.model(images, targets)
|
loss_dict = self.model(images, targets)
|
||||||
|
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
||||||
loss = sum(loss_dict.values())
|
loss = sum(loss_dict.values())
|
||||||
|
|
||||||
# log everything
|
# log everything
|
||||||
self.log_dict(loss_dict)
|
self.log_dict(loss_dict)
|
||||||
self.log("train/loss", loss)
|
self.log("train/loss", loss)
|
||||||
|
|
||||||
return {"loss": loss, "log": loss_dict}
|
return loss
|
||||||
|
|
||||||
# def validation_step(self, batch, batch_idx):
|
def on_validation_epoch_start(self):
|
||||||
# # unpack batch
|
if self.coco is None:
|
||||||
# images, targets = batch
|
self.coco = get_coco_api_from_dataset(self.trainer.val_dataloaders[0].dataset)
|
||||||
|
|
||||||
# # enable eval mode
|
# init coco evaluator
|
||||||
# # self.detector.eval()
|
self.coco_evaluator = CocoEvaluator(self.coco, self.iou_types)
|
||||||
|
|
||||||
# # make a prediction
|
def validation_step(self, batch, batch_idx):
|
||||||
# preds = self.model(images)
|
# unpack batch
|
||||||
|
images, targets = batch
|
||||||
|
|
||||||
# # compute validation loss
|
# compute metrics using pycocotools
|
||||||
# self.val_loss = torch.mean(
|
outputs = self.model(images)
|
||||||
# torch.stack(
|
res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
|
||||||
# [
|
self.coco_evaluator.update(res)
|
||||||
# self.accuracy(
|
|
||||||
# target,
|
|
||||||
# pred["boxes"],
|
|
||||||
# iou_threshold=0.5,
|
|
||||||
# )
|
|
||||||
# for target, pred in zip(targets, preds)
|
|
||||||
# ],
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return self.val_loss
|
# compute validation loss
|
||||||
|
self.model.train()
|
||||||
|
loss_dict = self.model(images, targets)
|
||||||
|
loss_dict = {f"valid/{key}": val for key, val in loss_dict.items()}
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
# def accuracy(self, src_boxes, pred_boxes, iou_threshold=1.0):
|
return loss_dict
|
||||||
# """
|
|
||||||
# The accuracy method is not the one used in the evaluator but very similar
|
|
||||||
# """
|
|
||||||
# total_gt = len(src_boxes)
|
|
||||||
# total_pred = len(pred_boxes)
|
|
||||||
# if total_gt > 0 and total_pred > 0:
|
|
||||||
|
|
||||||
# # Define the matcher and distance matrix based on iou
|
def validation_epoch_end(self, outputs):
|
||||||
# matcher = Matcher(iou_threshold, iou_threshold, allow_low_quality_matches=False)
|
# accumulate all predictions
|
||||||
# match_quality_matrix = box_iou(src_boxes, pred_boxes)
|
self.coco_evaluator.accumulate()
|
||||||
|
self.coco_evaluator.summarize()
|
||||||
|
|
||||||
# results = matcher(match_quality_matrix)
|
YEET = {
|
||||||
|
"valid,bbox,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[0],
|
||||||
|
"valid,bbox,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[1],
|
||||||
|
"valid,bbox,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[2],
|
||||||
|
"valid,bbox,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[3],
|
||||||
|
"valid,bbox,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[4],
|
||||||
|
"valid,bbox,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[5],
|
||||||
|
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["bbox"].stats[6],
|
||||||
|
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["bbox"].stats[7],
|
||||||
|
"valid,bbox,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[8],
|
||||||
|
"valid,bbox,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[9],
|
||||||
|
"valid,bbox,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[10],
|
||||||
|
"valid,bbox,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["bbox"].stats[11],
|
||||||
|
"valid,segm,AP,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[0],
|
||||||
|
"valid,segm,AP,IoU=0.50,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[1],
|
||||||
|
"valid,segm,AP,IoU=0.75,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[2],
|
||||||
|
"valid,segm,AP,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[3],
|
||||||
|
"valid,segm,AP,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[4],
|
||||||
|
"valid,segm,AP,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[5],
|
||||||
|
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=1": self.coco_evaluator.coco_eval["segm"].stats[6],
|
||||||
|
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=10": self.coco_evaluator.coco_eval["segm"].stats[7],
|
||||||
|
"valid,segm,AR,IoU=0.50:0.,area=all,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[8],
|
||||||
|
"valid,segm,AR,IoU=0.50:0.,area=small,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[9],
|
||||||
|
"valid,segm,AR,IoU=0.50:0.,area=medium,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[10],
|
||||||
|
"valid,segm,AR,IoU=0.50:0.,area=large,maxDets=100": self.coco_evaluator.coco_eval["segm"].stats[11],
|
||||||
|
}
|
||||||
|
|
||||||
# true_positive = torch.count_nonzero(results.unique() != -1)
|
self.log_dict(YEET)
|
||||||
# matched_elements = results[results > -1]
|
|
||||||
|
|
||||||
# # in Matcher, a pred element can be matched only twice
|
|
||||||
# false_positive = torch.count_nonzero(results == -1) + (
|
|
||||||
# len(matched_elements) - len(matched_elements.unique())
|
|
||||||
# )
|
|
||||||
# false_negative = total_gt - true_positive
|
|
||||||
|
|
||||||
# return true_positive / (true_positive + false_positive + false_negative)
|
|
||||||
|
|
||||||
# elif total_gt == 0:
|
|
||||||
# if total_pred > 0:
|
|
||||||
# return torch.tensor(0.0).cuda()
|
|
||||||
# else:
|
|
||||||
# return torch.tensor(1.0).cuda()
|
|
||||||
# elif total_gt > 0 and total_pred == 0:
|
|
||||||
# return torch.tensor(0.0).cuda()
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.SGD(
|
optimizer = torch.optim.SGD(
|
||||||
|
|
12
src/train.py
12
src/train.py
|
@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||||
pl.seed_everything(wandb.config.SEED, workers=True)
|
pl.seed_everything(wandb.config.SEED, workers=True)
|
||||||
|
|
||||||
# Create Network
|
# Create Network
|
||||||
model = MRCNNModule(
|
module = MRCNNModule(
|
||||||
hidden_layer_size=-1,
|
hidden_layer_size=-1,
|
||||||
n_classes=2,
|
n_classes=2,
|
||||||
)
|
)
|
||||||
|
@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||||
# model.load_state_dict(state_dict)
|
# model.load_state_dict(state_dict)
|
||||||
|
|
||||||
# log gradients and weights regularly
|
# log gradients and weights regularly
|
||||||
logger.watch(model.model, log="all")
|
logger.watch(module.model, log="all")
|
||||||
|
|
||||||
# Create the dataloaders
|
# Create the dataloaders
|
||||||
datamodule = Spheres()
|
datamodule = Spheres()
|
||||||
|
@ -51,14 +51,16 @@ if __name__ == "__main__":
|
||||||
precision=wandb.config.PRECISION,
|
precision=wandb.config.PRECISION,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
log_every_n_steps=5,
|
log_every_n_steps=5,
|
||||||
# val_check_interval=100,
|
val_check_interval=50,
|
||||||
callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
|
callbacks=[RichProgressBar(), ArtifactLog()],
|
||||||
|
# callbacks=[RichProgressBar(), ArtifactLog(), TableLog()],
|
||||||
# profiler="advanced",
|
# profiler="advanced",
|
||||||
num_sanity_val_steps=0,
|
num_sanity_val_steps=0,
|
||||||
|
devices=[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
# actually train the model
|
# actually train the model
|
||||||
trainer.fit(model=model, datamodule=datamodule)
|
trainer.fit(model=module, datamodule=datamodule)
|
||||||
|
|
||||||
# stop wandb
|
# stop wandb
|
||||||
wandb.run.finish()
|
wandb.run.finish()
|
||||||
|
|
|
@ -67,31 +67,31 @@ class TableLog(Callback):
|
||||||
|
|
||||||
|
|
||||||
class ArtifactLog(Callback):
|
class ArtifactLog(Callback):
|
||||||
def on_fit_start(self, trainer, pl_module):
|
# def on_fit_start(self, trainer, pl_module):
|
||||||
self.best = 1
|
# self.best = 1
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
# create checkpoint
|
# create checkpoint
|
||||||
torch.save(pl_module.state_dict(), "checkpoints/model.pth")
|
torch.save(pl_module.state_dict(), "checkpoints/model.pth")
|
||||||
|
|
||||||
def on_validation_epoch_start(self, trainer, pl_module):
|
# def on_validation_epoch_start(self, trainer, pl_module):
|
||||||
self.dices = []
|
# self.dices = []
|
||||||
|
|
||||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
# def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||||
# unpacking
|
# # unpacking
|
||||||
metrics, _ = outputs
|
# metrics, _ = outputs
|
||||||
self.dices.append(metrics["dice"].cpu())
|
# self.dices.append(metrics["dice"].cpu())
|
||||||
|
|
||||||
def on_validation_epoch_end(self, trainer, pl_module):
|
# def on_validation_epoch_end(self, trainer, pl_module):
|
||||||
dice = np.mean(self.dices)
|
# dice = np.mean(self.dices)
|
||||||
|
|
||||||
if dice < self.best:
|
# if dice < self.best:
|
||||||
self.best = dice
|
# self.best = dice
|
||||||
|
|
||||||
# create checkpoint
|
# # create checkpoint
|
||||||
trainer.save_checkpoint("checkpoints/model.ckpt")
|
# trainer.save_checkpoint("checkpoints/model.ckpt")
|
||||||
|
|
||||||
# log artifact
|
# # log artifact
|
||||||
artifact = wandb.Artifact("ckpt", type="model")
|
# artifact = wandb.Artifact("ckpt", type="model")
|
||||||
artifact.add_file("checkpoints/model.ckpt")
|
# artifact.add_file("checkpoints/model.ckpt")
|
||||||
wandb.run.log_artifact(artifact)
|
# wandb.run.log_artifact(artifact)
|
||||||
|
|
194
src/utils/coco_eval.py
Normal file
194
src/utils/coco_eval.py
Normal file
|
@ -0,0 +1,194 @@
|
||||||
|
import copy
|
||||||
|
import io
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pycocotools.mask as mask_util
|
||||||
|
import torch
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
|
||||||
|
import utils
|
||||||
|
|
||||||
|
|
||||||
|
class CocoEvaluator:
|
||||||
|
def __init__(self, coco_gt, iou_types):
|
||||||
|
if not isinstance(iou_types, (list, tuple)):
|
||||||
|
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
|
||||||
|
coco_gt = copy.deepcopy(coco_gt)
|
||||||
|
self.coco_gt = coco_gt
|
||||||
|
|
||||||
|
self.iou_types = iou_types
|
||||||
|
self.coco_eval = {}
|
||||||
|
for iou_type in iou_types:
|
||||||
|
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
|
||||||
|
|
||||||
|
self.img_ids = []
|
||||||
|
self.eval_imgs = {k: [] for k in iou_types}
|
||||||
|
|
||||||
|
def update(self, predictions):
|
||||||
|
img_ids = list(np.unique(list(predictions.keys())))
|
||||||
|
self.img_ids.extend(img_ids)
|
||||||
|
|
||||||
|
for iou_type in self.iou_types:
|
||||||
|
results = self.prepare(predictions, iou_type)
|
||||||
|
with redirect_stdout(io.StringIO()):
|
||||||
|
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
|
||||||
|
coco_eval = self.coco_eval[iou_type]
|
||||||
|
|
||||||
|
coco_eval.cocoDt = coco_dt
|
||||||
|
coco_eval.params.imgIds = list(img_ids)
|
||||||
|
img_ids, eval_imgs = evaluate(coco_eval)
|
||||||
|
|
||||||
|
self.eval_imgs[iou_type].append(eval_imgs)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for iou_type in self.iou_types:
|
||||||
|
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
|
||||||
|
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
|
||||||
|
|
||||||
|
def accumulate(self):
|
||||||
|
for coco_eval in self.coco_eval.values():
|
||||||
|
coco_eval.accumulate()
|
||||||
|
|
||||||
|
def summarize(self):
|
||||||
|
for iou_type, coco_eval in self.coco_eval.items():
|
||||||
|
print(f"IoU metric: {iou_type}")
|
||||||
|
coco_eval.summarize()
|
||||||
|
|
||||||
|
def prepare(self, predictions, iou_type):
|
||||||
|
if iou_type == "bbox":
|
||||||
|
return self.prepare_for_coco_detection(predictions)
|
||||||
|
if iou_type == "segm":
|
||||||
|
return self.prepare_for_coco_segmentation(predictions)
|
||||||
|
if iou_type == "keypoints":
|
||||||
|
return self.prepare_for_coco_keypoint(predictions)
|
||||||
|
raise ValueError(f"Unknown iou type {iou_type}")
|
||||||
|
|
||||||
|
def prepare_for_coco_detection(self, predictions):
|
||||||
|
coco_results = []
|
||||||
|
for original_id, prediction in predictions.items():
|
||||||
|
if len(prediction) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
boxes = prediction["boxes"]
|
||||||
|
boxes = convert_to_xywh(boxes).tolist()
|
||||||
|
scores = prediction["scores"].tolist()
|
||||||
|
labels = prediction["labels"].tolist()
|
||||||
|
|
||||||
|
coco_results.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"image_id": original_id,
|
||||||
|
"category_id": labels[k],
|
||||||
|
"bbox": box,
|
||||||
|
"score": scores[k],
|
||||||
|
}
|
||||||
|
for k, box in enumerate(boxes)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return coco_results
|
||||||
|
|
||||||
|
def prepare_for_coco_segmentation(self, predictions):
|
||||||
|
coco_results = []
|
||||||
|
for original_id, prediction in predictions.items():
|
||||||
|
if len(prediction) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scores = prediction["scores"]
|
||||||
|
labels = prediction["labels"]
|
||||||
|
masks = prediction["masks"]
|
||||||
|
|
||||||
|
masks = masks > 0.5
|
||||||
|
|
||||||
|
scores = prediction["scores"].tolist()
|
||||||
|
labels = prediction["labels"].tolist()
|
||||||
|
|
||||||
|
rles = [
|
||||||
|
mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
|
||||||
|
for mask in masks.cpu()
|
||||||
|
]
|
||||||
|
for rle in rles:
|
||||||
|
rle["counts"] = rle["counts"].decode("utf-8")
|
||||||
|
|
||||||
|
coco_results.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"image_id": original_id,
|
||||||
|
"category_id": labels[k],
|
||||||
|
"segmentation": rle,
|
||||||
|
"score": scores[k],
|
||||||
|
}
|
||||||
|
for k, rle in enumerate(rles)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return coco_results
|
||||||
|
|
||||||
|
def prepare_for_coco_keypoint(self, predictions):
|
||||||
|
coco_results = []
|
||||||
|
for original_id, prediction in predictions.items():
|
||||||
|
if len(prediction) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
boxes = prediction["boxes"]
|
||||||
|
boxes = convert_to_xywh(boxes).tolist()
|
||||||
|
scores = prediction["scores"].tolist()
|
||||||
|
labels = prediction["labels"].tolist()
|
||||||
|
keypoints = prediction["keypoints"]
|
||||||
|
keypoints = keypoints.flatten(start_dim=1).tolist()
|
||||||
|
|
||||||
|
coco_results.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"image_id": original_id,
|
||||||
|
"category_id": labels[k],
|
||||||
|
"keypoints": keypoint,
|
||||||
|
"score": scores[k],
|
||||||
|
}
|
||||||
|
for k, keypoint in enumerate(keypoints)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return coco_results
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_xywh(boxes):
|
||||||
|
xmin, ymin, xmax, ymax = boxes.unbind(1)
|
||||||
|
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def merge(img_ids, eval_imgs):
|
||||||
|
all_img_ids = utils.all_gather(img_ids)
|
||||||
|
all_eval_imgs = utils.all_gather(eval_imgs)
|
||||||
|
|
||||||
|
merged_img_ids = []
|
||||||
|
for p in all_img_ids:
|
||||||
|
merged_img_ids.extend(p)
|
||||||
|
|
||||||
|
merged_eval_imgs = []
|
||||||
|
for p in all_eval_imgs:
|
||||||
|
merged_eval_imgs.append(p)
|
||||||
|
|
||||||
|
merged_img_ids = np.array(merged_img_ids)
|
||||||
|
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
|
||||||
|
|
||||||
|
# keep only unique (and in sorted order) images
|
||||||
|
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
|
||||||
|
merged_eval_imgs = merged_eval_imgs[..., idx]
|
||||||
|
|
||||||
|
return merged_img_ids, merged_eval_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
|
||||||
|
img_ids, eval_imgs = merge(img_ids, eval_imgs)
|
||||||
|
img_ids = list(img_ids)
|
||||||
|
eval_imgs = list(eval_imgs.flatten())
|
||||||
|
|
||||||
|
coco_eval.evalImgs = eval_imgs
|
||||||
|
coco_eval.params.imgIds = img_ids
|
||||||
|
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(imgs):
|
||||||
|
with redirect_stdout(io.StringIO()):
|
||||||
|
imgs.evaluate()
|
||||||
|
return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
|
232
src/utils/coco_utils.py
Normal file
232
src/utils/coco_utils.py
Normal file
|
@ -0,0 +1,232 @@
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
import torchvision
|
||||||
|
from pycocotools import mask as coco_mask
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
|
||||||
|
|
||||||
|
class FilterAndRemapCocoCategories:
|
||||||
|
def __init__(self, categories, remap=True):
|
||||||
|
self.categories = categories
|
||||||
|
self.remap = remap
|
||||||
|
|
||||||
|
def __call__(self, image, target):
|
||||||
|
anno = target["annotations"]
|
||||||
|
anno = [obj for obj in anno if obj["category_id"] in self.categories]
|
||||||
|
if not self.remap:
|
||||||
|
target["annotations"] = anno
|
||||||
|
return image, target
|
||||||
|
anno = copy.deepcopy(anno)
|
||||||
|
for obj in anno:
|
||||||
|
obj["category_id"] = self.categories.index(obj["category_id"])
|
||||||
|
target["annotations"] = anno
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
|
||||||
|
def convert_coco_poly_to_mask(segmentations, height, width):
|
||||||
|
masks = []
|
||||||
|
for polygons in segmentations:
|
||||||
|
rles = coco_mask.frPyObjects(polygons, height, width)
|
||||||
|
mask = coco_mask.decode(rles)
|
||||||
|
if len(mask.shape) < 3:
|
||||||
|
mask = mask[..., None]
|
||||||
|
mask = torch.as_tensor(mask, dtype=torch.uint8)
|
||||||
|
mask = mask.any(dim=2)
|
||||||
|
masks.append(mask)
|
||||||
|
if masks:
|
||||||
|
masks = torch.stack(masks, dim=0)
|
||||||
|
else:
|
||||||
|
masks = torch.zeros((0, height, width), dtype=torch.uint8)
|
||||||
|
return masks
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertCocoPolysToMask:
|
||||||
|
def __call__(self, image, target):
|
||||||
|
w, h = image.size
|
||||||
|
|
||||||
|
image_id = target["image_id"]
|
||||||
|
image_id = torch.tensor([image_id])
|
||||||
|
|
||||||
|
anno = target["annotations"]
|
||||||
|
|
||||||
|
anno = [obj for obj in anno if obj["iscrowd"] == 0]
|
||||||
|
|
||||||
|
boxes = [obj["bbox"] for obj in anno]
|
||||||
|
# guard against no boxes via resizing
|
||||||
|
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
|
||||||
|
boxes[:, 2:] += boxes[:, :2]
|
||||||
|
boxes[:, 0::2].clamp_(min=0, max=w)
|
||||||
|
boxes[:, 1::2].clamp_(min=0, max=h)
|
||||||
|
|
||||||
|
classes = [obj["category_id"] for obj in anno]
|
||||||
|
classes = torch.tensor(classes, dtype=torch.int64)
|
||||||
|
|
||||||
|
segmentations = [obj["segmentation"] for obj in anno]
|
||||||
|
masks = convert_coco_poly_to_mask(segmentations, h, w)
|
||||||
|
|
||||||
|
keypoints = None
|
||||||
|
if anno and "keypoints" in anno[0]:
|
||||||
|
keypoints = [obj["keypoints"] for obj in anno]
|
||||||
|
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
|
||||||
|
num_keypoints = keypoints.shape[0]
|
||||||
|
if num_keypoints:
|
||||||
|
keypoints = keypoints.view(num_keypoints, -1, 3)
|
||||||
|
|
||||||
|
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
||||||
|
boxes = boxes[keep]
|
||||||
|
classes = classes[keep]
|
||||||
|
masks = masks[keep]
|
||||||
|
if keypoints is not None:
|
||||||
|
keypoints = keypoints[keep]
|
||||||
|
|
||||||
|
target = {}
|
||||||
|
target["boxes"] = boxes
|
||||||
|
target["labels"] = classes
|
||||||
|
target["masks"] = masks
|
||||||
|
target["image_id"] = image_id
|
||||||
|
if keypoints is not None:
|
||||||
|
target["keypoints"] = keypoints
|
||||||
|
|
||||||
|
# for conversion to coco api
|
||||||
|
area = torch.tensor([obj["area"] for obj in anno])
|
||||||
|
iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
|
||||||
|
target["area"] = area
|
||||||
|
target["iscrowd"] = iscrowd
|
||||||
|
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
|
||||||
|
def _coco_remove_images_without_annotations(dataset, cat_list=None):
|
||||||
|
def _has_only_empty_bbox(anno):
|
||||||
|
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
|
||||||
|
|
||||||
|
def _count_visible_keypoints(anno):
|
||||||
|
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
|
||||||
|
|
||||||
|
min_keypoints_per_image = 10
|
||||||
|
|
||||||
|
def _has_valid_annotation(anno):
|
||||||
|
# if it's empty, there is no annotation
|
||||||
|
if len(anno) == 0:
|
||||||
|
return False
|
||||||
|
# if all boxes have close to zero area, there is no annotation
|
||||||
|
if _has_only_empty_bbox(anno):
|
||||||
|
return False
|
||||||
|
# keypoints task have a slight different critera for considering
|
||||||
|
# if an annotation is valid
|
||||||
|
if "keypoints" not in anno[0]:
|
||||||
|
return True
|
||||||
|
# for keypoint detection tasks, only consider valid images those
|
||||||
|
# containing at least min_keypoints_per_image
|
||||||
|
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not isinstance(dataset, torchvision.datasets.CocoDetection):
|
||||||
|
raise TypeError(
|
||||||
|
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
|
||||||
|
)
|
||||||
|
ids = []
|
||||||
|
for ds_idx, img_id in enumerate(dataset.ids):
|
||||||
|
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||||
|
anno = dataset.coco.loadAnns(ann_ids)
|
||||||
|
if cat_list:
|
||||||
|
anno = [obj for obj in anno if obj["category_id"] in cat_list]
|
||||||
|
if _has_valid_annotation(anno):
|
||||||
|
ids.append(ds_idx)
|
||||||
|
|
||||||
|
dataset = torch.utils.data.Subset(dataset, ids)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_iou_types(model):
|
||||||
|
model_without_ddp = model
|
||||||
|
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||||
|
model_without_ddp = model.module
|
||||||
|
iou_types = ["bbox"]
|
||||||
|
if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
|
||||||
|
iou_types.append("segm")
|
||||||
|
if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
|
||||||
|
iou_types.append("keypoints")
|
||||||
|
return iou_types
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_coco_api(ds):
|
||||||
|
coco_ds = COCO()
|
||||||
|
# annotation IDs need to start at 1, not 0, see torchvision issue #1530
|
||||||
|
ann_id = 1
|
||||||
|
dataset = {"images": [], "categories": [], "annotations": []}
|
||||||
|
categories = set()
|
||||||
|
# for img_idx in range(len(ds)):
|
||||||
|
for img_idx in range(50):
|
||||||
|
# find better way to get target
|
||||||
|
# targets = ds.get_annotations(img_idx)
|
||||||
|
img, targets = ds[img_idx]
|
||||||
|
image_id = targets["image_id"].item()
|
||||||
|
img_dict = {}
|
||||||
|
img_dict["id"] = image_id
|
||||||
|
img_dict["height"] = img.shape[-2]
|
||||||
|
img_dict["width"] = img.shape[-1]
|
||||||
|
dataset["images"].append(img_dict)
|
||||||
|
bboxes = targets["boxes"].clone()
|
||||||
|
bboxes[:, 2:] -= bboxes[:, :2]
|
||||||
|
bboxes = bboxes.tolist()
|
||||||
|
labels = targets["labels"].tolist()
|
||||||
|
areas = targets["area"].tolist()
|
||||||
|
iscrowd = targets["iscrowd"].tolist()
|
||||||
|
if "masks" in targets:
|
||||||
|
masks = targets["masks"]
|
||||||
|
# make masks Fortran contiguous for coco_mask
|
||||||
|
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
|
||||||
|
if "keypoints" in targets:
|
||||||
|
keypoints = targets["keypoints"]
|
||||||
|
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
|
||||||
|
num_objs = len(bboxes)
|
||||||
|
for i in range(num_objs):
|
||||||
|
ann = {}
|
||||||
|
ann["image_id"] = image_id
|
||||||
|
ann["bbox"] = bboxes[i]
|
||||||
|
ann["category_id"] = labels[i]
|
||||||
|
categories.add(labels[i])
|
||||||
|
ann["area"] = areas[i]
|
||||||
|
ann["iscrowd"] = iscrowd[i]
|
||||||
|
ann["id"] = ann_id
|
||||||
|
if "masks" in targets:
|
||||||
|
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
|
||||||
|
if "keypoints" in targets:
|
||||||
|
ann["keypoints"] = keypoints[i]
|
||||||
|
ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
|
||||||
|
dataset["annotations"].append(ann)
|
||||||
|
ann_id += 1
|
||||||
|
dataset["categories"] = [{"id": i} for i in sorted(categories)]
|
||||||
|
coco_ds.dataset = dataset
|
||||||
|
coco_ds.createIndex()
|
||||||
|
return coco_ds
|
||||||
|
|
||||||
|
|
||||||
|
def get_coco_api_from_dataset(dataset):
|
||||||
|
for _ in range(10):
|
||||||
|
if isinstance(dataset, torchvision.datasets.CocoDetection):
|
||||||
|
break
|
||||||
|
if isinstance(dataset, torch.utils.data.Subset):
|
||||||
|
dataset = dataset.dataset
|
||||||
|
if isinstance(dataset, torchvision.datasets.CocoDetection):
|
||||||
|
return dataset.coco
|
||||||
|
return convert_to_coco_api(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
class CocoDetection(torchvision.datasets.CocoDetection):
|
||||||
|
def __init__(self, img_folder, ann_file, transforms):
|
||||||
|
super().__init__(img_folder, ann_file)
|
||||||
|
self._transforms = transforms
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
img, target = super().__getitem__(idx)
|
||||||
|
image_id = self.ids[idx]
|
||||||
|
target = dict(image_id=image_id, annotations=target)
|
||||||
|
if self._transforms is not None:
|
||||||
|
img, target = self._transforms(img, target)
|
||||||
|
return img, target
|
13
wandb.yaml
13
wandb.yaml
|
@ -5,8 +5,6 @@ DIR_VALID_IMG:
|
||||||
DIR_SPHERE:
|
DIR_SPHERE:
|
||||||
value: "/media/disk1/lfainsin/SPHERES/"
|
value: "/media/disk1/lfainsin/SPHERES/"
|
||||||
|
|
||||||
FEATURES:
|
|
||||||
value: [8, 16, 32, 64]
|
|
||||||
N_CHANNELS:
|
N_CHANNELS:
|
||||||
value: 3
|
value: 3
|
||||||
N_CLASSES:
|
N_CLASSES:
|
||||||
|
@ -29,17 +27,12 @@ DEVICE:
|
||||||
WORKERS:
|
WORKERS:
|
||||||
value: 16
|
value: 16
|
||||||
|
|
||||||
IMG_SIZE:
|
|
||||||
value: 512
|
|
||||||
SPHERES:
|
|
||||||
value: 3
|
|
||||||
|
|
||||||
EPOCHS:
|
EPOCHS:
|
||||||
value: 10
|
value: 10
|
||||||
TRAIN_BATCH_SIZE:
|
TRAIN_BATCH_SIZE:
|
||||||
value: 8
|
value: 10
|
||||||
VAL_BATCH_SIZE:
|
VALID_BATCH_SIZE:
|
||||||
value: 0
|
value: 2
|
||||||
PREFETCH_FACTOR:
|
PREFETCH_FACTOR:
|
||||||
value: 2
|
value: 2
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue