mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
fix: dropping precision 16, unable to export in float16
This commit is contained in:
parent
eb3dabe8d7
commit
9d36719335
|
@ -67,24 +67,24 @@ class MRCNNModule(pl.LightningModule):
|
|||
self.model = get_model_instance_segmentation(n_classes)
|
||||
|
||||
# onnx export
|
||||
self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
|
||||
# self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
|
||||
|
||||
# torchmetrics
|
||||
self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
|
||||
self.metric_segm = MeanAveragePrecision(iou_type="segm")
|
||||
|
||||
def forward(self, imgs: torch.Tensor) -> Prediction: # type: ignore
|
||||
"""Make a forward pass (prediction), usefull for onnx export.
|
||||
# def forward(self, imgs: torch.Tensor) -> Prediction: # type: ignore
|
||||
# """Make a forward pass (prediction), usefull for onnx export.
|
||||
|
||||
Args:
|
||||
imgs (torch.Tensor): the images whose prediction we wish to make
|
||||
# Args:
|
||||
# imgs (torch.Tensor): the images whose prediction we wish to make
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the predictions
|
||||
"""
|
||||
self.model.eval()
|
||||
pred: Prediction = self.model(imgs)
|
||||
return pred
|
||||
# Returns:
|
||||
# torch.Tensor: the predictions
|
||||
# """
|
||||
# self.model.eval()
|
||||
# pred: Prediction = self.model(imgs)
|
||||
# return pred
|
||||
|
||||
def training_step(self, batch: torch.Tensor, batch_idx: int) -> float: # type: ignore
|
||||
"""PyTorch training step.
|
||||
|
@ -146,15 +146,22 @@ class MRCNNModule(pl.LightningModule):
|
|||
Args:
|
||||
outputs (List[Prediction]): list of predictions from validation steps
|
||||
"""
|
||||
# compute and log bounding boxes metrics
|
||||
metric_dict = self.metric_bbox.compute()
|
||||
metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
|
||||
self.log_dict(metric_dict)
|
||||
# compute metrics
|
||||
metric_dict_bbox = self.metric_bbox.compute()
|
||||
metric_dict_segm = self.metric_segm.compute()
|
||||
metric_dict_sum = {
|
||||
f"valid/sum/{k}": metric_dict_bbox.get(k, 0) + metric_dict_segm.get(k, 0)
|
||||
for k in set(metric_dict_bbox) & set(metric_dict_segm)
|
||||
}
|
||||
|
||||
# compute and log semgentation metrics
|
||||
metric_dict = self.metric_segm.compute()
|
||||
metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
|
||||
self.log_dict(metric_dict)
|
||||
# change metrics keys
|
||||
metric_dict_bbox = {f"valid/bbox/{key}": val for key, val in metric_dict_bbox.items()}
|
||||
metric_dict_segm = {f"valid/segm/{key}": val for key, val in metric_dict_segm.items()}
|
||||
|
||||
# log metrics
|
||||
self.log_dict(metric_dict_bbox)
|
||||
self.log_dict(metric_dict_segm)
|
||||
self.log_dict(metric_dict_sum)
|
||||
|
||||
def configure_optimizers(self) -> Dict[str, Any]:
|
||||
"""PyTorch optimizers and Schedulers.
|
||||
|
|
613
src/predict copy.ipynb
Normal file
613
src/predict copy.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
52
src/tmp.py
Normal file
52
src/tmp.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
import torchvision.transforms as T
|
||||
|
||||
from data.dataset import SyntheticDataset
|
||||
from utils import RandomPaste
|
||||
|
||||
# augmentation pipeline handed to the synthetic dataset
augmentations = [
    A.LongestMaxSize(max_size=1024),
    A.Flip(),
    RandomPaste(5, "/media/disk1/lfainsin/SPHERES/WHITE", "/dev/null"),
    A.ToGray(p=0.01),
    A.ISONoise(),
    A.ImageCompression(),
]
transform = A.Compose(augmentations)

dataset = SyntheticDataset(
    image_dir="/media/disk1/lfainsin/BACKGROUND/coco/",
    transform=transform,
)

# NOTE: `transform` is rebound on purpose: the dataset keeps its own reference
# to the albumentations pipeline above, while the code below uses `transform`
# to convert the dataset outputs to PIL images before saving them.
transform = T.ToPILImage()
|
||||
|
||||
|
||||
def render(i, image, mask):
    """Save one generated sample (image and mask) to the prerender directory.

    Both inputs are converted with the module-level ``transform`` and written
    to ``/media/disk1/lfainsin/TRAIN_prerender/<i zero-padded to 6 digits>/``
    as ``image.jpg`` and ``MASK.PNG``. The directory is created if missing.

    Args:
        i: index of the sample, used to name the output directory
        image: image to save, in a form accepted by ``transform``
        mask: mask to save, in a form accepted by ``transform``
    """
    pil_image = transform(image)
    pil_mask = transform(mask)

    # join with pathlib rather than string concatenation: the previous
    # f-string prefix already ended with "/" and produced "…//image.jpg"
    sample_dir = Path("/media/disk1/lfainsin/TRAIN_prerender") / f"{i:06d}"
    sample_dir.mkdir(parents=True, exist_ok=True)

    pil_image.save(sample_dir / "image.jpg")
    pil_mask.save(sample_dir / "MASK.PNG")
|
||||
|
||||
|
||||
def renderlist(list_i, dataset):
    """Render every sample of `dataset` whose index is listed in `list_i`.

    Args:
        list_i: iterable of indices to fetch from the dataset
        dataset: indexable object yielding an (image, mask) pair per index
    """
    for index in list_i:
        sample_image, sample_mask = dataset[index]
        render(index, sample_image, sample_mask)
|
||||
|
||||
|
||||
# split the dataset indices into 16 * 5 chunks, one chunk per worker thread
index_chunks = np.array_split(range(len(dataset)), 16 * 5)

workers = []
for chunk in index_chunks:
    worker = Thread(target=renderlist, args=(chunk, dataset))
    worker.start()
    workers.append(worker)

# block until every worker has finished rendering its chunk
for worker in workers:
    worker.join()
|
10
src/train.py
10
src/train.py
|
@ -35,7 +35,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Create Network
|
||||
module = MRCNNModule(
|
||||
n_classes=3,
|
||||
n_classes=2,
|
||||
)
|
||||
|
||||
# load checkpoint
|
||||
|
@ -59,17 +59,19 @@ if __name__ == "__main__":
|
|||
precision=wandb.config.PRECISION,
|
||||
logger=logger,
|
||||
log_every_n_steps=5,
|
||||
val_check_interval=200,
|
||||
val_check_interval=250,
|
||||
callbacks=[
|
||||
EarlyStopping(monitor="valid/bbox/map", mode="max", patience=10, min_delta=0.01),
|
||||
ModelCheckpoint(monitor="valid/bbox/map", mode="max"),
|
||||
EarlyStopping(monitor="valid/sum/map", mode="max", patience=10, min_delta=0.01),
|
||||
ModelCheckpoint(monitor="valid/sum/map", mode="max"),
|
||||
# ModelPruning("l1_unstructured", amount=0.5),
|
||||
LearningRateMonitor(log_momentum=True),
|
||||
# StochasticWeightAveraging(swa_lrs=1e-2),
|
||||
RichModelSummary(max_depth=2),
|
||||
RichProgressBar(),
|
||||
TableLog(),
|
||||
],
|
||||
# profiler="advanced",
|
||||
gradient_clip_val=1,
|
||||
num_sanity_val_steps=3,
|
||||
devices=[0],
|
||||
)
|
||||
|
|
|
@ -2,12 +2,13 @@ import wandb
|
|||
from pytorch_lightning.callbacks import Callback
|
||||
|
||||
columns = [
|
||||
"ID",
|
||||
"image",
|
||||
]
|
||||
class_labels = {
|
||||
1: "sphere",
|
||||
2: "sphere_gt",
|
||||
2: "chrome",
|
||||
10: "sphere_gt",
|
||||
20: "chrome_gt",
|
||||
}
|
||||
|
||||
|
||||
|
@ -19,20 +20,21 @@ class TableLog(Callback):
|
|||
# unpacking
|
||||
images, targets = batch
|
||||
|
||||
for i, (image, target) in enumerate(
|
||||
zip(
|
||||
images,
|
||||
targets,
|
||||
)
|
||||
for image, target in zip(
|
||||
images,
|
||||
targets,
|
||||
):
|
||||
rows.append(
|
||||
[
|
||||
i,
|
||||
wandb.Image(
|
||||
image.cpu(),
|
||||
masks={
|
||||
"ground_truth": {
|
||||
"mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy() * 2,
|
||||
"mask_data": (target["masks"] * target["labels"][:, None, None])
|
||||
.max(dim=0)
|
||||
.values.mul(10)
|
||||
.cpu()
|
||||
.numpy(),
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
|
@ -57,12 +59,10 @@ class TableLog(Callback):
|
|||
# unpacking
|
||||
images, targets = batch
|
||||
|
||||
for i, (image, target, pred) in enumerate(
|
||||
zip(
|
||||
images,
|
||||
targets,
|
||||
outputs,
|
||||
)
|
||||
for image, target, pred in zip(
|
||||
images,
|
||||
targets,
|
||||
outputs,
|
||||
):
|
||||
box_data_gt = [
|
||||
{
|
||||
|
@ -73,7 +73,7 @@ class TableLog(Callback):
|
|||
"maxY": int(target["boxes"][j][3]),
|
||||
},
|
||||
"domain": "pixel",
|
||||
"class_id": 2,
|
||||
"class_id": int(target["labels"][j] * 10),
|
||||
"class_labels": class_labels,
|
||||
}
|
||||
for j in range(len(target["labels"]))
|
||||
|
@ -88,7 +88,7 @@ class TableLog(Callback):
|
|||
"maxY": int(pred["boxes"][j][3]),
|
||||
},
|
||||
"domain": "pixel",
|
||||
"class_id": 1,
|
||||
"class_id": int(pred["labels"][j]),
|
||||
"box_caption": f"{pred['scores'][j]:0.3f}",
|
||||
"class_labels": class_labels,
|
||||
}
|
||||
|
@ -97,16 +97,22 @@ class TableLog(Callback):
|
|||
|
||||
self.rows.append(
|
||||
[
|
||||
i,
|
||||
wandb.Image(
|
||||
image.cpu(),
|
||||
masks={
|
||||
"ground_truth": {
|
||||
"mask_data": target["masks"].cpu().sum(dim=0).int().numpy() * 2,
|
||||
"mask_data": (target["masks"] * target["labels"][:, None, None])
|
||||
.max(dim=0)
|
||||
.values.mul(10)
|
||||
.cpu()
|
||||
.numpy(),
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
"predictions": {
|
||||
"mask_data": pred["masks"].cpu().sum(dim=0).int().numpy(),
|
||||
"mask_data": (pred["masks"] * pred["labels"][:, None, None])
|
||||
.max(dim=0)
|
||||
.values.cpu()
|
||||
.numpy(),
|
||||
"class_labels": class_labels,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -46,18 +46,16 @@ class RandomPaste(A.DualTransform):
|
|||
self.scale_range = scale_range
|
||||
self.nb = nb
|
||||
|
||||
self.augmentation_datas: List[AugmentationData] = []
|
||||
|
||||
@property
|
||||
def targets_as_params(self):
|
||||
return ["image"]
|
||||
|
||||
def apply(self, img, **params):
|
||||
def apply(self, img, augmentation_datas, **params):
|
||||
# convert img to Image, needed for `paste` function
|
||||
img = Image.fromarray(img)
|
||||
|
||||
# paste spheres
|
||||
for augmentation in self.augmentation_datas:
|
||||
for augmentation in augmentation_datas:
|
||||
paste_img_aug = T.functional.adjust_contrast(
|
||||
augmentation.paste_img,
|
||||
contrast_factor=augmentation.contrast,
|
||||
|
@ -98,11 +96,11 @@ class RandomPaste(A.DualTransform):
|
|||
|
||||
return np.array(img.convert("RGB"))
|
||||
|
||||
def apply_to_mask(self, mask, **params):
|
||||
def apply_to_mask(self, mask, augmentation_datas, **params):
|
||||
# convert mask to Image, needed for `paste` function
|
||||
mask = Image.fromarray(mask)
|
||||
|
||||
for augmentation in self.augmentation_datas:
|
||||
for augmentation in augmentation_datas:
|
||||
paste_mask_aug = T.functional.affine(
|
||||
augmentation.paste_mask,
|
||||
scale=0.95,
|
||||
|
@ -125,6 +123,8 @@ class RandomPaste(A.DualTransform):
|
|||
return np.array(mask.convert("L"))
|
||||
|
||||
def get_params_dependent_on_targets(self, params):
|
||||
# init augmentation list
|
||||
augmentation_datas: List[AugmentationData] = []
|
||||
|
||||
# load target image (w/ transparency)
|
||||
target_img = params["image"]
|
||||
|
@ -133,19 +133,19 @@ class RandomPaste(A.DualTransform):
|
|||
# generate augmentations
|
||||
ite = 0
|
||||
NB = rd.randint(1, self.nb)
|
||||
while len(self.augmentation_datas) < NB:
|
||||
while len(augmentation_datas) < NB:
|
||||
if ite > 100:
|
||||
break
|
||||
else:
|
||||
ite += 1
|
||||
|
||||
# choose a random sphere image and its corresponding mask
|
||||
if rd.random() > 0.5:
|
||||
if rd.random() > 0.5 or len(self.chrome_sphere_images) == 0:
|
||||
img_path = rd.choice(self.sphere_images)
|
||||
value = len(self.augmentation_datas) + 1
|
||||
value = len(augmentation_datas) + 1
|
||||
else:
|
||||
img_path = rd.choice(self.chrome_sphere_images)
|
||||
value = 255 - len(self.augmentation_datas)
|
||||
value = 255 - len(augmentation_datas)
|
||||
mask_path = img_path.parent.joinpath("MASK.PNG")
|
||||
|
||||
# load paste assets
|
||||
|
@ -161,7 +161,7 @@ class RandomPaste(A.DualTransform):
|
|||
shape = np.array(paste_shape * scale, dtype=np.uint)
|
||||
|
||||
try:
|
||||
self.augmentation_datas.append(
|
||||
augmentation_datas.append(
|
||||
AugmentationData(
|
||||
position=(
|
||||
rd.randint(0, target_shape[1] - shape[1]),
|
||||
|
@ -178,12 +178,19 @@ class RandomPaste(A.DualTransform):
|
|||
paste_img=paste_img,
|
||||
paste_mask=paste_mask,
|
||||
value=value,
|
||||
other_augmentations=self.augmentation_datas,
|
||||
target_shape=tuple(target_shape),
|
||||
other_augmentations=augmentation_datas,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
params.update(
|
||||
{
|
||||
"augmentation_datas": augmentation_datas,
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
|
@ -194,6 +201,7 @@ class AugmentationData:
|
|||
position: Tuple[int, int]
|
||||
|
||||
shape: Tuple[int, int]
|
||||
target_shape: Tuple[int, int]
|
||||
angle: float
|
||||
|
||||
brightness: float
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
DIR_TRAIN_IMG:
|
||||
value: "/media/disk1/lfainsin/TRAIN_prerender_old/"
|
||||
value: "/media/disk1/lfainsin/TRAIN_prerender/"
|
||||
DIR_VALID_IMG:
|
||||
value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/"
|
||||
# DIR_SPHERE:
|
||||
|
@ -19,7 +19,7 @@ BENCHMARK:
|
|||
DETERMINISTIC:
|
||||
value: False
|
||||
PRECISION:
|
||||
value: 16
|
||||
value: 32
|
||||
SEED:
|
||||
value: 69420
|
||||
DEVICE:
|
||||
|
|
Loading…
Reference in a new issue