Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-09 15:02:03 +00:00)
fix: dropping precision 16, unable to export in float16

commit 9d36719335 (parent eb3dabe8d7)
@@ -67,24 +67,24 @@ class MRCNNModule(pl.LightningModule):
         self.model = get_model_instance_segmentation(n_classes)
 
         # onnx export
-        self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
+        # self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()
 
         # torchmetrics
         self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
         self.metric_segm = MeanAveragePrecision(iou_type="segm")
 
-    def forward(self, imgs: torch.Tensor) -> Prediction:  # type: ignore
-        """Make a forward pass (prediction), usefull for onnx export.
+    # def forward(self, imgs: torch.Tensor) -> Prediction:  # type: ignore
+    #     """Make a forward pass (prediction), usefull for onnx export.
 
-        Args:
-            imgs (torch.Tensor): the images whose prediction we wish to make
+    #     Args:
+    #         imgs (torch.Tensor): the images whose prediction we wish to make
 
-        Returns:
-            torch.Tensor: the predictions
-        """
-        self.model.eval()
-        pred: Prediction = self.model(imgs)
-        return pred
+    #     Returns:
+    #         torch.Tensor: the predictions
+    #     """
+    #     self.model.eval()
+    #     pred: Prediction = self.model(imgs)
+    #     return pred
 
     def training_step(self, batch: torch.Tensor, batch_idx: int) -> float:  # type: ignore
         """PyTorch training step.
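Note: the hunk above removes the half-precision ONNX path entirely; with precision 16 the example input was cast with .half(), which the export could not handle (hence the commit title). For reference only, a minimal float32 export sketch, not part of this commit: the output file name mrcnn.onnx is arbitrary, the MRCNNModule import is assumed to come from this project, and checkpoint loading is omitted.

    import torch

    # assumes `from <project module> import MRCNNModule`; n_classes=2 matches train.py below
    module = MRCNNModule(n_classes=2)
    module.model.eval()

    dummy = torch.randn(1, 3, 1024, 1024)  # float32 example input, same shape as the commented-out one
    torch.onnx.export(module.model, dummy, "mrcnn.onnx", opset_version=11)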
@@ -146,15 +146,22 @@ class MRCNNModule(pl.LightningModule):
         Args:
             outputs (List[Prediction]): list of predictions from validation steps
         """
-        # compute and log bounding boxes metrics
-        metric_dict = self.metric_bbox.compute()
-        metric_dict = {f"valid/bbox/{key}": val for key, val in metric_dict.items()}
-        self.log_dict(metric_dict)
+        # compute metrics
+        metric_dict_bbox = self.metric_bbox.compute()
+        metric_dict_segm = self.metric_segm.compute()
+        metric_dict_sum = {
+            f"valid/sum/{k}": metric_dict_bbox.get(k, 0) + metric_dict_segm.get(k, 0)
+            for k in set(metric_dict_bbox) & set(metric_dict_segm)
+        }
 
-        # compute and log semgentation metrics
-        metric_dict = self.metric_segm.compute()
-        metric_dict = {f"valid/segm/{key}": val for key, val in metric_dict.items()}
-        self.log_dict(metric_dict)
+        # change metrics keys
+        metric_dict_bbox = {f"valid/bbox/{key}": val for key, val in metric_dict_bbox.items()}
+        metric_dict_segm = {f"valid/segm/{key}": val for key, val in metric_dict_segm.items()}
+
+        # log metrics
+        self.log_dict(metric_dict_bbox)
+        self.log_dict(metric_dict_segm)
+        self.log_dict(metric_dict_sum)
 
     def configure_optimizers(self) -> Dict[str, Any]:
         """PyTorch optimizers and Schedulers.
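A toy illustration of the new valid/sum/* keys; the values below are made up, while the real dicts come from torchmetrics MeanAveragePrecision.compute().

    metric_dict_bbox = {"map": 0.5, "map_50": 0.75}   # made-up bbox metrics
    metric_dict_segm = {"map": 0.25, "map_50": 0.5}   # made-up segm metrics

    metric_dict_sum = {
        f"valid/sum/{k}": metric_dict_bbox.get(k, 0) + metric_dict_segm.get(k, 0)
        for k in set(metric_dict_bbox) & set(metric_dict_segm)
    }
    print(metric_dict_sum)  # {'valid/sum/map': 0.75, 'valid/sum/map_50': 1.25} (key order may vary)

valid/sum/map is the quantity that EarlyStopping and ModelCheckpoint now monitor in src/train.py below.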
613  src/predict copy.ipynb  (new file)
File diff suppressed because one or more lines are too long

52  src/tmp.py  (new file)
@@ -0,0 +1,52 @@
+from pathlib import Path
+from threading import Thread
+
+import albumentations as A
+import numpy as np
+import torchvision.transforms as T
+
+from data.dataset import SyntheticDataset
+from utils import RandomPaste
+
+transform = A.Compose(
+    [
+        A.LongestMaxSize(max_size=1024),
+        A.Flip(),
+        RandomPaste(5, "/media/disk1/lfainsin/SPHERES/WHITE", "/dev/null"),
+        A.ToGray(p=0.01),
+        A.ISONoise(),
+        A.ImageCompression(),
+    ],
+)
+
+dataset = SyntheticDataset(image_dir="/media/disk1/lfainsin/BACKGROUND/coco/", transform=transform)
+transform = T.ToPILImage()
+
+
+def render(i, image, mask):
+    image = transform(image)
+    mask = transform(mask)
+
+    path = f"/media/disk1/lfainsin/TRAIN_prerender/{i:06d}/"
+    Path(path).mkdir(parents=True, exist_ok=True)
+
+    image.save(f"{path}/image.jpg")
+    mask.save(f"{path}/MASK.PNG")
+
+
+def renderlist(list_i, dataset):
+    for i in list_i:
+        image, mask = dataset[i]
+        render(i, image, mask)
+
+
+sublists = np.array_split(range(len(dataset)), 16 * 5)
+threads = []
+for sublist in sublists:
+    t = Thread(target=renderlist, args=(sublist, dataset))
+    t.start()
+    threads.append(t)
+
+# join all threads
+for t in threads:
+    t.join()
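src/tmp.py pre-renders the augmented synthetic dataset to disk across 80 threads (16 * 5 sublists). A quick, hypothetical sanity check of the resulting layout; the index 000000 is only an example of the {i:06d} naming scheme used above.

    from PIL import Image

    sample_dir = "/media/disk1/lfainsin/TRAIN_prerender/000000"  # hypothetical first sample
    image = Image.open(f"{sample_dir}/image.jpg")
    mask = Image.open(f"{sample_dir}/MASK.PNG")
    print(image.size, mask.size)  # longest side capped at 1024 px by A.LongestMaxSize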
10  src/train.py
@@ -35,7 +35,7 @@ if __name__ == "__main__":
 
     # Create Network
     module = MRCNNModule(
-        n_classes=3,
+        n_classes=2,
     )
 
     # load checkpoint
@@ -59,17 +59,19 @@ if __name__ == "__main__":
         precision=wandb.config.PRECISION,
         logger=logger,
         log_every_n_steps=5,
-        val_check_interval=200,
+        val_check_interval=250,
         callbacks=[
-            EarlyStopping(monitor="valid/bbox/map", mode="max", patience=10, min_delta=0.01),
-            ModelCheckpoint(monitor="valid/bbox/map", mode="max"),
+            EarlyStopping(monitor="valid/sum/map", mode="max", patience=10, min_delta=0.01),
+            ModelCheckpoint(monitor="valid/sum/map", mode="max"),
             # ModelPruning("l1_unstructured", amount=0.5),
             LearningRateMonitor(log_momentum=True),
+            # StochasticWeightAveraging(swa_lrs=1e-2),
             RichModelSummary(max_depth=2),
             RichProgressBar(),
             TableLog(),
         ],
         # profiler="advanced",
+        gradient_clip_val=1,
         num_sanity_val_steps=3,
         devices=[0],
     )
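Besides switching the monitored metric to valid/sum/map, the trainer now sets gradient_clip_val=1, which asks Lightning to clip the gradient norm before each optimizer step. A rough stand-alone equivalent of that clipping, using a toy model and data rather than the project's:

    import torch

    model = torch.nn.Linear(4, 1)                        # toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # roughly what gradient_clip_val=1 does (norm clipping)
    optimizer.step()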
@@ -2,12 +2,13 @@ import wandb
 from pytorch_lightning.callbacks import Callback
 
 columns = [
-    "ID",
     "image",
 ]
 class_labels = {
     1: "sphere",
-    2: "sphere_gt",
+    2: "chrome",
+    10: "sphere_gt",
+    20: "chrome_gt",
 }
 
 
@@ -19,20 +20,21 @@ class TableLog(Callback):
         # unpacking
         images, targets = batch
 
-        for i, (image, target) in enumerate(
-            zip(
-                images,
-                targets,
-            )
+        for image, target in zip(
+            images,
+            targets,
         ):
             rows.append(
                 [
-                    i,
                     wandb.Image(
                         image.cpu(),
                         masks={
                             "ground_truth": {
-                                "mask_data": (target["masks"].cpu().sum(dim=0) > 0.5).int().numpy() * 2,
+                                "mask_data": (target["masks"] * target["labels"][:, None, None])
+                                .max(dim=0)
+                                .values.mul(10)
+                                .cpu()
+                                .numpy(),
                                 "class_labels": class_labels,
                             },
                         },
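The ground-truth mask is now encoded per pixel as label * 10, so the ids 10 and 20 resolve to sphere_gt and chrome_gt in class_labels above, while predicted masks keep the raw labels 1 and 2. A toy illustration with two 2x2 instance masks:

    import torch

    masks = torch.tensor([[[1, 0], [0, 0]],    # instance 0: top-left pixel
                          [[0, 1], [0, 0]]])   # instance 1: top-right pixel
    labels = torch.tensor([1, 2])              # sphere, chrome

    mask_data = (masks * labels[:, None, None]).max(dim=0).values.mul(10)
    print(mask_data)  # tensor([[10, 20], [ 0,  0]]) -> class ids sphere_gt / chrome_gt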
@@ -57,12 +59,10 @@ class TableLog(Callback):
         # unpacking
         images, targets = batch
 
-        for i, (image, target, pred) in enumerate(
-            zip(
-                images,
-                targets,
-                outputs,
-            )
+        for image, target, pred in zip(
+            images,
+            targets,
+            outputs,
         ):
             box_data_gt = [
                 {
@@ -73,7 +73,7 @@ class TableLog(Callback):
                         "maxY": int(target["boxes"][j][3]),
                     },
                     "domain": "pixel",
-                    "class_id": 2,
+                    "class_id": int(target["labels"][j] * 10),
                     "class_labels": class_labels,
                 }
                 for j in range(len(target["labels"]))
@@ -88,7 +88,7 @@ class TableLog(Callback):
                         "maxY": int(pred["boxes"][j][3]),
                     },
                     "domain": "pixel",
-                    "class_id": 1,
+                    "class_id": int(pred["labels"][j]),
                     "box_caption": f"{pred['scores'][j]:0.3f}",
                     "class_labels": class_labels,
                 }
@@ -97,16 +97,22 @@ class TableLog(Callback):
 
             self.rows.append(
                 [
-                    i,
                     wandb.Image(
                         image.cpu(),
                         masks={
                             "ground_truth": {
-                                "mask_data": target["masks"].cpu().sum(dim=0).int().numpy() * 2,
+                                "mask_data": (target["masks"] * target["labels"][:, None, None])
+                                .max(dim=0)
+                                .values.mul(10)
+                                .cpu()
+                                .numpy(),
                                 "class_labels": class_labels,
                             },
                             "predictions": {
-                                "mask_data": pred["masks"].cpu().sum(dim=0).int().numpy(),
+                                "mask_data": (pred["masks"] * pred["labels"][:, None, None])
+                                .max(dim=0)
+                                .values.cpu()
+                                .numpy(),
                                 "class_labels": class_labels,
                             },
                         },
@@ -46,18 +46,16 @@ class RandomPaste(A.DualTransform):
         self.scale_range = scale_range
         self.nb = nb
 
-        self.augmentation_datas: List[AugmentationData] = []
-
     @property
     def targets_as_params(self):
         return ["image"]
 
-    def apply(self, img, **params):
+    def apply(self, img, augmentation_datas, **params):
         # convert img to Image, needed for `paste` function
         img = Image.fromarray(img)
 
         # paste spheres
-        for augmentation in self.augmentation_datas:
+        for augmentation in augmentation_datas:
             paste_img_aug = T.functional.adjust_contrast(
                 augmentation.paste_img,
                 contrast_factor=augmentation.contrast,
@@ -98,11 +96,11 @@ class RandomPaste(A.DualTransform):
 
         return np.array(img.convert("RGB"))
 
-    def apply_to_mask(self, mask, **params):
+    def apply_to_mask(self, mask, augmentation_datas, **params):
         # convert mask to Image, needed for `paste` function
         mask = Image.fromarray(mask)
 
-        for augmentation in self.augmentation_datas:
+        for augmentation in augmentation_datas:
             paste_mask_aug = T.functional.affine(
                 augmentation.paste_mask,
                 scale=0.95,
@@ -125,6 +123,8 @@ class RandomPaste(A.DualTransform):
         return np.array(mask.convert("L"))
 
     def get_params_dependent_on_targets(self, params):
+        # init augmentation list
+        augmentation_datas: List[AugmentationData] = []
 
         # load target image (w/ transparency)
         target_img = params["image"]
@@ -133,19 +133,19 @@ class RandomPaste(A.DualTransform):
         # generate augmentations
         ite = 0
         NB = rd.randint(1, self.nb)
-        while len(self.augmentation_datas) < NB:
+        while len(augmentation_datas) < NB:
             if ite > 100:
                 break
             else:
                 ite += 1
 
             # choose a random sphere image and its corresponding mask
-            if rd.random() > 0.5:
+            if rd.random() > 0.5 or len(self.chrome_sphere_images) == 0:
                 img_path = rd.choice(self.sphere_images)
-                value = len(self.augmentation_datas) + 1
+                value = len(augmentation_datas) + 1
             else:
                 img_path = rd.choice(self.chrome_sphere_images)
-                value = 255 - len(self.augmentation_datas)
+                value = 255 - len(augmentation_datas)
             mask_path = img_path.parent.joinpath("MASK.PNG")
 
             # load paste assets
@@ -161,7 +161,7 @@ class RandomPaste(A.DualTransform):
             shape = np.array(paste_shape * scale, dtype=np.uint)
 
             try:
-                self.augmentation_datas.append(
+                augmentation_datas.append(
                     AugmentationData(
                         position=(
                             rd.randint(0, target_shape[1] - shape[1]),
@@ -178,12 +178,19 @@ class RandomPaste(A.DualTransform):
                         paste_img=paste_img,
                         paste_mask=paste_mask,
                         value=value,
-                        other_augmentations=self.augmentation_datas,
+                        target_shape=tuple(target_shape),
+                        other_augmentations=augmentation_datas,
                     )
                 )
             except ValueError:
                 continue
 
+        params.update(
+            {
+                "augmentation_datas": augmentation_datas,
+            }
+        )
+
         return params
 
 
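This refactor makes RandomPaste stateless: targets_as_params exposes the image to get_params_dependent_on_targets, and whatever that method puts into params is forwarded as keyword arguments to apply and apply_to_mask, instead of being stashed on self between calls. A minimal sketch of that albumentations pattern; AddOffset is an illustrative transform, not the project's RandomPaste, and it is written against the older albumentations API that RandomPaste itself uses.

    import albumentations as A
    import numpy as np


    class AddOffset(A.DualTransform):
        @property
        def targets_as_params(self):
            return ["image"]

        def get_params_dependent_on_targets(self, params):
            # whatever is returned here reaches apply()/apply_to_mask() as kwargs
            return {"offset": int(params["image"].mean()) % 10}

        def apply(self, img, offset, **params):
            return img + offset

        def apply_to_mask(self, mask, offset, **params):
            return mask  # leave the mask untouched in this toy example


    out = AddOffset(p=1.0)(image=np.zeros((4, 4), dtype=np.uint8), mask=np.zeros((4, 4), dtype=np.uint8))
    print(out["image"][0, 0])  # 0 here, since the all-zero image yields offset 0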
@@ -194,6 +201,7 @@ class AugmentationData:
     position: Tuple[int, int]
 
     shape: Tuple[int, int]
+    target_shape: Tuple[int, int]
     angle: float
 
     brightness: float
@@ -1,5 +1,5 @@
 DIR_TRAIN_IMG:
-  value: "/media/disk1/lfainsin/TRAIN_prerender_old/"
+  value: "/media/disk1/lfainsin/TRAIN_prerender/"
 DIR_VALID_IMG:
   value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/"
 # DIR_SPHERE:
@@ -19,7 +19,7 @@ BENCHMARK:
 DETERMINISTIC:
   value: False
 PRECISION:
-  value: 16
+  value: 32
 SEED:
   value: 69420
 DEVICE: