refactor!: move files to python modules

feat: add back (broken) faster rcnn model
Laurent Fainsin 2023-04-05 20:11:01 +02:00
parent 8f24f83b4a
commit d8bba4481a
12 changed files with 464 additions and 35 deletions


@@ -17,8 +17,8 @@ dependencies:
- pytorch
- torchvision
- torchaudio
- pytorch-cuda=11.8
- lightning
- pytorch-cuda
- lightning # currently broken, install manually with pip
# deep learning libraries
- transformers
- datasets
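A plausible manual workaround for the broken conda package, sketched below; the exact pip invocation is an assumption, not part of this commit:

# after creating the environment, install lightning manually:
#   pip install lightning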


@@ -1,7 +1,7 @@
import datasets
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import CombinedLoader
from lightning.pytorch import LightningDataModule
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data import DataLoader
from torchvision.transforms import AugMix
from transformers import DetrFeatureExtractor
@@ -41,9 +41,9 @@ class DETRDataModule(LightningDataModule):
def prepare_data(self):
"""Download data and prepare for training."""
# load datasets
self.illumination = datasets.load_dataset("src/spheres_illumination.py", split="train")
self.render = datasets.load_dataset("src/spheres_synth.py", split="train")
self.real = datasets.load_dataset("src/spheres.py", split="train")
self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
# split datasets
self.illumination = self.illumination.train_test_split(test_size=0.01)
@@ -56,9 +56,7 @@ class DETRDataModule(LightningDataModule):
print(f"real: {self.real}")
# other datasets
self.test_ds = datasets.load_dataset("src/spheres_illumination.py", split="test")
# self.predict_ds = datasets.load_dataset("src/spheres.py", split="train").shuffle().select(range(16))
self.predict_ds = datasets.load_dataset("src/spheres_predict.py", split="train")
self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
# define AugMix transform
self.mix = AugMix()
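A minimal sketch of what the AugMix transform does to one sample (illustrative only; the variable names are assumptions):

# mix = AugMix()               # torchvision defaults: severity=3, mixture_width=3
# augmented = mix(pil_image)   # randomly chains and blends augmentations, returns a PIL image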


@@ -0,0 +1,314 @@
import datasets
import torch
from lightning.pytorch import LightningDataModule
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data import DataLoader
from torchvision.transforms import AugMix
from transformers import DetrFeatureExtractor
class FasterRCNNDataModule(LightningDataModule):
"""PyTorch Lightning data module for Faster RCNN."""
def __init__(
self,
num_workers: int = 8,
batch_size: int = 5,
prefetch_factor: int = 2,
model_name: str = "facebook/detr-resnet-50",
persistent_workers: bool = True,
):
"""Constructor.
Args:
num_workers (int, optional): Number of workers.
batch_size (int, optional): Batch size.
prefetch_factor (int, optional): Prefetch factor.
val_split (float, optional): Validation split.
model_name (str, optional): Model name.
"""
super().__init__()
# save params
self.num_workers = num_workers
self.batch_size = batch_size
self.prefetch_factor = prefetch_factor
self.persistent_workers = persistent_workers
# get feature extractor
self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
def prepare_data(self):
"""Download data and prepare for training."""
# load datasets
self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
# split datasets
self.illumination = self.illumination.train_test_split(test_size=0.01)
self.render = self.render.train_test_split(test_size=0.01)
self.real = self.real.train_test_split(test_size=0.1)
# print some info
print(f"illumination: {self.illumination}")
print(f"render: {self.render}")
print(f"real: {self.real}")
# other datasets
self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
# define AugMix transform
self.mix = AugMix()
# useful mappings
self.labels = self.real["test"].features["objects"][0]["category_id"].names
self.id2label = {k: v for k, v in enumerate(self.labels)}
self.label2id = {v: k for k, v in enumerate(self.labels)}
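# e.g. (illustrative) if self.labels == ["Matte", "Shiny", "Chrome"], then
#   id2label == {0: "Matte", 1: "Shiny", 2: "Chrome"}
#   label2id == {"Matte": 0, "Shiny": 1, "Chrome": 2}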
def train_transform(self, batch):
"""Training transform.
Args:
batch (dict): Batch precollated by HuggingFace datasets.
Structure is similar to the following:
{
"image": list[PIL.Image],
"image_id": list[int],
"objects": [
{
"bbox": list[float, 4],
"category_id": int,
}
]
}
Returns:
dict: Augmented and processed batch.
Structure is similar to the following:
{
"pixel_values": TensorType["batch", "canal", "width", "height"],
"pixel_mask": TensorType["batch", 1200, 1200],
"labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],
}
"""
# extract images, ids and objects from batch
images = batch["image"]
ids = batch["image_id"]
objects = batch["objects"]
# apply AugMix transform
images_mixed = [self.mix(image) for image in images]
# build targets for feature extractor
targets = [
{
"image_id": id,
"annotations": object,
}
for id, object in zip(ids, objects)
]
# process images and targets with feature extractor for DETR
processed = self.feature_extractor(
images=images_mixed,
annotations=targets,
return_tensors="pt",
)
for label in processed["labels"]:
# renamed "class_labels" to "labels"
# add 1 since 0 is reserved for background
label["labels"] = label["class_labels"] + 1
del label["class_labels"]
# format boxes from [xc, yc, w, h] to [x1, y1, x2, y2]
width_height = label["boxes"][:, 2:]
label["boxes"][:, :2] = label["boxes"][:, :2] - width_height / 2
label["boxes"][:, 2:] = label["boxes"][:, :2] + width_height / 2
# convert from normalized to absolute coordinates
label["boxes"][:, 0] *= label["size"][1]
label["boxes"][:, 1] *= label["size"][0]
label["boxes"][:, 2] *= label["size"][1]
label["boxes"][:, 3] *= label["size"][0]
return processed
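# worked example (illustrative): a normalized box [xc=0.5, yc=0.5, w=0.2, h=0.4]
# with size == (h=100, w=200) becomes
#   x1 = (0.5 - 0.1) * 200 = 80,   y1 = (0.5 - 0.2) * 100 = 30,
#   x2 = (0.5 + 0.1) * 200 = 120,  y2 = (0.5 + 0.2) * 100 = 70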
def val_transform(self, batch):
"""Validation transform.
Just like train_transform, but without AugMix.
"""
# extract images, ids and objects from batch
images = batch["image"]
ids = batch["image_id"]
objects = batch["objects"]
# build targets for feature extractor
targets = [
{
"image_id": id,
"annotations": object,
}
for id, object in zip(ids, objects)
]
processed = self.feature_extractor(
images=images,
annotations=targets,
return_tensors="pt",
)
for label in processed["labels"]:
# renamed "class_labels" to "labels"
# add 1 since 0 is reserved for background
label["labels"] = label["class_labels"] + 1
del label["class_labels"]
# format boxes from [xcenter, ycenter, w, h] to [x1, y1, x2, y2]
center_x, center_y, width, height = label["boxes"].unbind(-1)
label["boxes"] = torch.stack(
# top left x, top left y, bottom right x, bottom right y
[
(center_x - 0.5 * width),
(center_y - 0.5 * height),
(center_x + 0.5 * width),
(center_y + 0.5 * height),
],
dim=-1,
)
# convert from normalized to absolute coordinates
label["boxes"][:, 0] *= label["size"][1]
label["boxes"][:, 1] *= label["size"][0]
label["boxes"][:, 2] *= label["size"][1]
label["boxes"][:, 3] *= label["size"][0]
return processed
def predict_transform(self, batch):
"""Prediction transform.
Just like val_transform, but also keeps the raw images.
"""
processed = self.val_transform(batch)
# add images to dict
processed["images"] = batch["image"]
return processed
def collate_fn(self, examples):
"""Collate function.
Convert list of dicts to dict of Tensors.
"""
return {
"pixel_values": torch.stack([data["pixel_values"] for data in examples]),
"labels": [data["labels"] for data in examples],
}
def collate_fn_predict(self, examples):
"""Collate function.
Convert list of dicts to dict of Tensors.
"""
return {
"pixel_values": torch.stack([data["pixel_values"] for data in examples]),
"labels": [data["labels"] for data in examples],
"images": [data["images"] for data in examples],
}
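# note (illustrative): stacking turns N tensors of shape (3, H, W) into a single
# (N, 3, H, W) tensor, while "labels" and "images" stay as plain per-sample lists,
# since their contents have variable sizes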
def train_dataloader(self):
"""Training dataloader."""
loaders = {
"illumination": DataLoader(
self.illumination["train"].with_transform(self.val_transform),
shuffle=True,
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
),
"render": DataLoader(
self.render["train"].with_transform(self.val_transform),
shuffle=True,
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
),
"real": DataLoader(
self.real["train"].with_transform(self.val_transform),
shuffle=True,
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
),
}
return CombinedLoader(loaders, mode="max_size_cycle")
def val_dataloader(self):
"""Validation dataloader."""
loaders = {
"illumination": DataLoader(
self.illumination["test"].with_transform(self.val_transform),
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
),
"render": DataLoader(
self.render["test"].with_transform(self.val_transform),
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
),
"real": DataLoader(
self.real["test"].with_transform(self.val_transform),
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
),
}
return CombinedLoader(loaders, mode="max_size_cycle")
def predict_dataloader(self):
"""Prediction dataloader."""
return DataLoader(
self.predict_ds.with_transform(self.predict_transform),
pin_memory=True,
persistent_workers=self.persistent_workers,
collate_fn=self.collate_fn_predict,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
if __name__ == "__main__":
# load data
dm = FasterRCNNDataModule()
dm.prepare_data()
ds = dm.train_dataloader()
for batch in ds:
print(batch)


@@ -0,0 +1,2 @@
from .DETR import DETRDataModule  # noqa: F401
from .FasterRCNN import FasterRCNNDataModule  # noqa: F401


@@ -7,9 +7,9 @@ import numpy as np
prefix = "/data/local-files/?d=spheres/"
dataset_path = pathlib.Path("./dataset_antoine_laurent/")
annotation_path = dataset_path / "annotations.json"
annotation_path = dataset_path / "annotations.json" # from labelstudio
_VERSION = "1.0.0"
_VERSION = "2.0.0"
_DESCRIPTION = ""
@@ -105,6 +105,7 @@ class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
for annotation in annotations:
assert annotation["value"]["keypointlabels"][0] == label
# convert labels
if label == "White":
label = "Matte"
elif label == "Black":


@@ -5,7 +5,7 @@ import datasets
dataset_path_train = pathlib.Path("./dataset_illumination/")
_VERSION = "1.0.0"
_VERSION = "2.0.0"
_DESCRIPTION = ""
@@ -90,9 +90,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
meta["gray"]["bounding_box"]["x"] / original_width * width,
meta["gray"]["bounding_box"]["y"] / original_height * height,
meta["gray"]["bounding_box"]["w"] / original_width * width,
meta["gray"]["bounding_box"]["h"] / original_height * height
meta["gray"]["bounding_box"]["h"] / original_height * height,
),
"Matte"
"Matte",
)
chrome = (
@@ -100,9 +100,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
meta["chrome"]["bounding_box"]["x"] / original_width * width,
meta["chrome"]["bounding_box"]["y"] / original_height * height,
meta["chrome"]["bounding_box"]["w"] / original_width * width,
meta["chrome"]["bounding_box"]["h"] / original_height * height
meta["chrome"]["bounding_box"]["h"] / original_height * height,
),
"Chrome"
"Chrome",
)
# generate data


@@ -4,7 +4,7 @@ import datasets
dataset_path = pathlib.Path("./dataset_predict/")
_VERSION = "1.0.0"
_VERSION = "2.0.2"
_DESCRIPTION = ""
@@ -60,7 +60,6 @@ class SpherePredict(datasets.GeneratorBasedBuilder):
jpgs = dataset_path.rglob("*.jpg")
for index, jpg in enumerate(jpgs):
print(index, jpg, 2)
# generate data


@@ -4,7 +4,7 @@ import datasets
dataset_path = pathlib.Path("./dataset_render/")
_VERSION = "1.0.0"
_VERSION = "2.0.0"
_DESCRIPTION = ""
@@ -20,7 +20,6 @@ _NAMES = [
class SphereSynth(datasets.GeneratorBasedBuilder):
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
@@ -78,8 +77,9 @@ class SphereSynth(datasets.GeneratorBasedBuilder):
1 - float(line[1]),
float(line[2]),
1 - float(line[3]),
line[4].strip()
) for line in lines
line[4].strip(),
)
for line in lines
]
bboxes = [


@@ -1,11 +1,11 @@
from datamodule import DETRDataModule
from module import DETR
from pytorch_lightning.callbacks import (
from datamodule import DETRDataModule, FasterRCNNDataModule # noqa: F401
from lightning.pytorch.callbacks import (
ModelCheckpoint,
RichModelSummary,
RichProgressBar,
)
from pytorch_lightning.cli import LightningCLI
from lightning.pytorch.cli import LightningCLI
from module import DETR, FasterRCNN # noqa: F401
class MyLightningCLI(LightningCLI):
@@ -15,12 +15,11 @@ class MyLightningCLI(LightningCLI):
"""Add arguments to parser."""
parser.set_defaults(
{
"trainer.multiple_trainloader_mode": "max_size_cycle",
"trainer.max_steps": 5000,
"trainer.max_epochs": 1,
"trainer.accelerator": "gpu",
"trainer.devices": "[1]",
"trainer.strategy": "dp",
"trainer.devices": "[0]",
"trainer.strategy": "auto",
"trainer.log_every_n_steps": 25,
"trainer.val_check_interval": 200,
"trainer.num_sanity_val_steps": 10,


@@ -1,6 +1,6 @@
import torch
from lightning.pytorch import LightningModule
from PIL import ImageDraw
from pytorch_lightning import LightningModule
from transformers import (
DetrForObjectDetection,
get_cosine_with_hard_restarts_schedule_with_warmup,
@@ -33,13 +33,14 @@ class DETR(LightningModule):
"""
super().__init__()
# replace COCO classification head with custom head
# get DETR model
self.net = DetrForObjectDetection.from_pretrained(
"facebook/detr-resnet-50",
ignore_mismatched_sizes=True,
num_queries=num_queries,
num_labels=num_labels,
)
self.net = torch.compile(self.net)  # torch.compile returns the optimized module
# cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
self.lr = lr
@@ -47,7 +48,6 @@
self.weight_decay = weight_decay
self.warmup_steps = warmup_steps
self.prediction_threshold = prediction_threshold
self.save_hyperparameters()
def forward(self, pixel_values, pixel_mask, **kwargs):
@@ -100,7 +100,7 @@ class DETR(LightningModule):
"""Training step."""
outputs = self.common_step(batch, batch_idx)
# logs metrics for each training_step and the average across the epoch
# logs metrics for each training_step
loss = 0
for dataloader_name, output in outputs.items():
loss += output["loss"]
@@ -117,7 +117,7 @@
"""Validation step."""
outputs = self.common_step(batch, batch_idx)
# logs metrics for each validation_step and the average across the epoch
# logs metrics for each validation_step
loss = 0
for dataloader_name, output in outputs.items():
loss += output["loss"]

src/module/FasterRCNN.py (new file)

@@ -0,0 +1,114 @@
import torch
import torchvision
from lightning.pytorch import LightningModule
from PIL import ImageDraw
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights, FastRCNNPredictor
def get_model_object_detection(n_classes: int):
"""Return a torchvision Faster R-CNN model for fine-tuning.
Args:
n_classes (int): number of classes the model should predict, background excluded
"""
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
box_detections_per_img=10,  # cap the number of detections per image, else OOM
)
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes + 1)
return model
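# sketch (illustrative): with the resnet50-fpn backbone, cls_score.in_features is 1024,
# so the new head is FastRCNNPredictor(1024, n_classes + 1), producing n_classes + 1
# class logits per region proposal, index 0 being the implicit background class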
class FasterRCNN(LightningModule):
"""Faster R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
def __init__(
self,
lr: float = 1e-4,
weight_decay: float = 1e-4,
num_labels: int = 3,
):
"""Constructor, build model, save hyperparameters."""
super().__init__()
# get Faster R-CNN model
self.net = get_model_object_detection(num_labels)
# hyperparameters
self.lr = lr
self.weight_decay = weight_decay
self.num_labels = num_labels
self.save_hyperparameters()
def forward(self, imgs, **kwargs):
"""Forward pass."""
return self.net(imgs, **kwargs)
def common_step(self, batches, batch_idx):
"""Run one forward pass per dataloader and collect the losses."""
# initialize outputs
outputs = {}
# for each dataloader
for dataloader_name, batch in batches.items():
# extract pixel_values and labels from batch
images = batch["pixel_values"]
targets = batch["labels"]
# forward pass
model_output = self(images, targets=targets)
# get loss
outputs[dataloader_name] = {
"loss": sum(model_output.values()),
"loss_dict": model_output,
}
return outputs
def training_step(self, batch, batch_idx):
"""Training step."""
outputs = self.common_step(batch, batch_idx)
# logs metrics for each training_step
loss = 0
for dataloader_name, output in outputs.items():
loss += output["loss"]
self.log(f"train_loss_{dataloader_name}", output["loss"])
for k, v in output["loss_dict"].items():
self.log(f"train_loss_{k}_{dataloader_name}", v.item())
self.log("lr", self.optimizers().param_groups[0]["lr"])
return loss
def validation_step(self, batch, batch_idx):
"""Validation step."""
# torchvision detection models only return their loss dict in train mode,
# so temporarily switch to train() under no_grad to get validation losses
self.train()
with torch.no_grad():
outputs = self.common_step(batch, batch_idx)
self.eval()
# logs metrics for each validation_step
loss = 0
for dataloader_name, output in outputs.items():
loss += output["loss"]
self.log(f"val_loss_{dataloader_name}", output["loss"])
for k, v in output["loss_dict"].items():
self.log(f"val_loss_{k}_{dataloader_name}", v.item())
return loss
def configure_optimizers(self):
"""PyTorch optimizers and Schedulers.
Returns:
dictionnary for PyTorch Lightning optimizer/scheduler configuration
"""
optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
return {
"optimizer": optimizer,
}

src/module/__init__.py (new file)

@@ -0,0 +1,2 @@
from .DETR import DETR  # noqa: F401
from .FasterRCNN import FasterRCNN  # noqa: F401