refactor!: move files to python modules

feat: add back (broken) faster rcnn model
Laurent Fainsin 2023-04-05 20:11:01 +02:00
parent 8f24f83b4a
commit d8bba4481a
12 changed files with 464 additions and 35 deletions

View file

@@ -17,8 +17,8 @@ dependencies:
   - pytorch
   - torchvision
   - torchaudio
-  - pytorch-cuda=11.8
-  - lightning
+  - pytorch-cuda
+  - lightning # currently broken, install manually with pip
   # deep learning libraries
   - transformers
   - datasets

View file

@@ -1,7 +1,7 @@
 import datasets
 import torch
-from pytorch_lightning import LightningDataModule
-from pytorch_lightning.utilities import CombinedLoader
+from lightning.pytorch import LightningDataModule
+from lightning.pytorch.utilities import CombinedLoader
 from torch.utils.data import DataLoader
 from torchvision.transforms import AugMix
 from transformers import DetrFeatureExtractor
@@ -41,9 +41,9 @@ class DETRDataModule(LightningDataModule):
     def prepare_data(self):
         """Download data and prepare for training."""
         # load datasets
-        self.illumination = datasets.load_dataset("src/spheres_illumination.py", split="train")
-        self.render = datasets.load_dataset("src/spheres_synth.py", split="train")
-        self.real = datasets.load_dataset("src/spheres.py", split="train")
+        self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
+        self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
+        self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
 
         # split datasets
         self.illumination = self.illumination.train_test_split(test_size=0.01)
@@ -56,9 +56,7 @@ class DETRDataModule(LightningDataModule):
         print(f"real: {self.real}")
 
         # other datasets
-        self.test_ds = datasets.load_dataset("src/spheres_illumination.py", split="test")
-        # self.predict_ds = datasets.load_dataset("src/spheres.py", split="train").shuffle().select(range(16))
-        self.predict_ds = datasets.load_dataset("src/spheres_predict.py", split="train")
+        self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
 
         # define AugMix transform
         self.mix = AugMix()
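Both hunks above track the Lightning 2.x packaging change: the standalone pytorch_lightning distribution now ships inside the unified lightning package (which is also why the conda recipe above falls back to a manual pip install). A minimal sketch of the equivalent imports, assuming Lightning >= 2.0:

    # Lightning >= 2.0 exposes the same classes under the unified package
    from lightning.pytorch import LightningDataModule, LightningModule, Trainer
    from lightning.pytorch.utilities import CombinedLoader

    # the legacy spelling still works if the standalone package is installed:
    # from pytorch_lightning import LightningDataModule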

View file

@@ -0,0 +1,314 @@
import datasets
import torch
from lightning.pytorch import LightningDataModule
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data import DataLoader
from torchvision.transforms import AugMix
from transformers import DetrFeatureExtractor


class FasterRCNNDataModule(LightningDataModule):
    """PyTorch Lightning data module for Faster RCNN."""

    def __init__(
        self,
        num_workers: int = 8,
        batch_size: int = 5,
        prefetch_factor: int = 2,
        model_name: str = "facebook/detr-resnet-50",
        persistent_workers: bool = True,
    ):
        """Constructor.

        Args:
            num_workers (int, optional): Number of workers.
            batch_size (int, optional): Batch size.
            prefetch_factor (int, optional): Prefetch factor.
            model_name (str, optional): Model name.
            persistent_workers (bool, optional): Keep workers alive between epochs.
        """
        super().__init__()

        # save params
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.prefetch_factor = prefetch_factor
        self.persistent_workers = persistent_workers

        # get feature extractor
        self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)

    def prepare_data(self):
        """Download data and prepare for training."""
        # load datasets
        self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
        self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
        self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")

        # split datasets
        self.illumination = self.illumination.train_test_split(test_size=0.01)
        self.render = self.render.train_test_split(test_size=0.01)
        self.real = self.real.train_test_split(test_size=0.1)

        # print some info
        print(f"illumination: {self.illumination}")
        print(f"render: {self.render}")
        print(f"real: {self.real}")

        # other datasets
        self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")

        # define AugMix transform
        self.mix = AugMix()

        # useful mappings
        self.labels = self.real["test"].features["objects"][0]["category_id"].names
        self.id2label = {k: v for k, v in enumerate(self.labels)}
        self.label2id = {v: k for k, v in enumerate(self.labels)}
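For reference, the two mappings are inverse enumerations of the dataset's class-name list. A tiny illustration with a hypothetical label list (the real names come from the category_id feature above):

    labels = ["Matte", "Chrome", "Shiny"]  # hypothetical class names
    id2label = {k: v for k, v in enumerate(labels)}  # {0: "Matte", 1: "Chrome", 2: "Shiny"}
    label2id = {v: k for k, v in enumerate(labels)}  # {"Matte": 0, "Chrome": 1, "Shiny": 2}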
    def train_transform(self, batch):
        """Training transform.

        Args:
            batch (dict): Batch precollated by HuggingFace datasets.
                Structure is similar to the following:
                {
                    "image": list[PIL.Image],
                    "image_id": list[int],
                    "objects": [
                        {
                            "bbox": list[float, 4],
                            "category_id": int,
                        }
                    ]
                }

        Returns:
            dict: Augmented and processed batch.
                Structure is similar to the following:
                {
                    "pixel_values": TensorType["batch", "channels", "width", "height"],
                    "pixel_mask": TensorType["batch", 1200, 1200],
                    "labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],
                }
        """
        # extract images, ids and objects from batch
        images = batch["image"]
        ids = batch["image_id"]
        objects = batch["objects"]

        # apply AugMix transform
        images_mixed = [self.mix(image) for image in images]

        # build targets for feature extractor
        targets = [
            {
                "image_id": id,
                "annotations": object,
            }
            for id, object in zip(ids, objects)
        ]

        # process images and targets with feature extractor for DETR
        processed = self.feature_extractor(
            images=images_mixed,
            annotations=targets,
            return_tensors="pt",
        )

        for label in processed["labels"]:
            # rename "class_labels" to "labels"
            # add 1 since 0 is reserved for background
            label["labels"] = label["class_labels"] + 1
            del label["class_labels"]

            # format boxes from [xc, yc, w, h] to [x1, y1, x2, y2]
            width_height = label["boxes"][:, 2:]
            label["boxes"][:, :2] = label["boxes"][:, :2] - width_height / 2
            label["boxes"][:, 2:] = label["boxes"][:, :2] + width_height  # x2y2 = x1y1 + wh

            # convert from normalized to absolute coordinates
            label["boxes"][:, 0] *= label["size"][1]
            label["boxes"][:, 1] *= label["size"][0]
            label["boxes"][:, 2] *= label["size"][1]
            label["boxes"][:, 3] *= label["size"][0]

        return processed
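torchvision ships the same cxcywh to xyxy conversion as a one-liner, which makes a handy sanity check for the manual math above (illustrative values only):

    import torch
    from torchvision.ops import box_convert

    boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized [xc, yc, w, h]
    print(box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy"))
    # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])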
    def val_transform(self, batch):
        """Validation transform.

        Just like the training transform, but without AugMix.
        """
        # extract images, ids and objects from batch
        images = batch["image"]
        ids = batch["image_id"]
        objects = batch["objects"]

        # build targets for feature extractor
        targets = [
            {
                "image_id": id,
                "annotations": object,
            }
            for id, object in zip(ids, objects)
        ]

        processed = self.feature_extractor(
            images=images,
            annotations=targets,
            return_tensors="pt",
        )

        for label in processed["labels"]:
            # rename "class_labels" to "labels"
            # add 1 since 0 is reserved for background
            label["labels"] = label["class_labels"] + 1
            del label["class_labels"]

            # format boxes from [xcenter, ycenter, w, h] to [x1, y1, x2, y2]
            center_x, center_y, width, height = label["boxes"].unbind(-1)
            label["boxes"] = torch.stack(
                # top left x, top left y, bottom right x, bottom right y
                [
                    (center_x - 0.5 * width),
                    (center_y - 0.5 * height),
                    (center_x + 0.5 * width),
                    (center_y + 0.5 * height),
                ],
                dim=-1,
            )

            # convert from normalized to absolute coordinates
            label["boxes"][:, 0] *= label["size"][1]
            label["boxes"][:, 1] *= label["size"][0]
            label["boxes"][:, 2] *= label["size"][1]
            label["boxes"][:, 3] *= label["size"][0]

        return processed

    def predict_transform(self, batch):
        """Prediction transform.

        Just like val_transform, but also keeps the raw images.
        """
        processed = self.val_transform(batch)

        # add images to dict
        processed["images"] = batch["image"]

        return processed
    def collate_fn(self, examples):
        """Collate function.

        Convert list of dicts to dict of Tensors.
        """
        return {
            "pixel_values": torch.stack([data["pixel_values"] for data in examples]),
            "labels": [data["labels"] for data in examples],
        }

    def collate_fn_predict(self, examples):
        """Collate function for prediction.

        Convert list of dicts to dict of Tensors, keeping the raw images.
        """
        return {
            "pixel_values": torch.stack([data["pixel_values"] for data in examples]),
            "labels": [data["labels"] for data in examples],
            "images": [data["images"] for data in examples],
        }
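The custom collate exists because the targets are ragged: the number of boxes varies per image, so only pixel_values can be stacked while "labels" must stay a plain list (torch's default_collate would fail on it). A self-contained sketch with dummy tensors:

    import torch

    examples = [
        {"pixel_values": torch.rand(3, 64, 64), "labels": {"boxes": torch.rand(2, 4)}},
        {"pixel_values": torch.rand(3, 64, 64), "labels": {"boxes": torch.rand(5, 4)}},
    ]
    batch = {
        "pixel_values": torch.stack([e["pixel_values"] for e in examples]),  # (2, 3, 64, 64)
        "labels": [e["labels"] for e in examples],  # 2 vs 5 boxes cannot be stacked
    }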
    def train_dataloader(self):
        """Training dataloader."""
        loaders = {
            "illumination": DataLoader(
                self.illumination["train"].with_transform(self.val_transform),
                shuffle=True,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "render": DataLoader(
                self.render["train"].with_transform(self.val_transform),
                shuffle=True,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "real": DataLoader(
                self.real["train"].with_transform(self.val_transform),
                shuffle=True,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
        }
        return CombinedLoader(loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation dataloader."""
        loaders = {
            "illumination": DataLoader(
                self.illumination["test"].with_transform(self.val_transform),
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "render": DataLoader(
                self.render["test"].with_transform(self.val_transform),
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "real": DataLoader(
                self.real["test"].with_transform(self.val_transform),
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
        }
        return CombinedLoader(loaders, mode="max_size_cycle")

    def predict_dataloader(self):
        """Prediction dataloader."""
        return DataLoader(
            self.predict_ds.with_transform(self.predict_transform),
            pin_memory=True,
            persistent_workers=self.persistent_workers,
            collate_fn=self.collate_fn_predict,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
        )


if __name__ == "__main__":
    # load data
    dm = FasterRCNNDataModule()
    dm.prepare_data()
    ds = dm.train_dataloader()
    for batch in ds:
        print(batch)
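With mode="max_size_cycle", shorter loaders are cycled until the longest is exhausted, and each step receives a single dict of batches keyed by loader name, which is what the modules' common_step iterates over. A toy sketch, assuming Lightning >= 2.0:

    from torch.utils.data import DataLoader
    from lightning.pytorch.utilities import CombinedLoader

    loaders = {
        "short": DataLoader(range(4), batch_size=2),  # 2 batches
        "long": DataLoader(range(8), batch_size=2),   # 4 batches
    }
    combined = CombinedLoader(loaders, mode="max_size_cycle")
    for batch in combined:  # 4 steps; "short" wraps around once
        # batch == {"short": tensor([...]), "long": tensor([...])}
        # note: newer Lightning versions may yield (batch, batch_idx, dataloader_idx)
        print(batch)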

View file

@@ -0,0 +1,2 @@
from DETR import DETRDataModule
from FasterRCNN import FasterRCNNDataModule
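As written these are implicit absolute imports, which only resolve when the package directory itself is on sys.path; the usual Python 3 spelling for a package __init__.py (presumably src/datamodule/__init__.py here) would be relative, as a sketch:

    # package-relative form of the same re-exports
    from .DETR import DETRDataModule  # noqa: F401
    from .FasterRCNN import FasterRCNNDataModule  # noqa: F401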

View file

@@ -7,9 +7,9 @@ import numpy as np
 prefix = "/data/local-files/?d=spheres/"
 dataset_path = pathlib.Path("./dataset_antoine_laurent/")
-annotation_path = dataset_path / "annotations.json"
+annotation_path = dataset_path / "annotations.json"  # from labelstudio
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.0"
 
 _DESCRIPTION = ""
@@ -105,6 +105,7 @@ class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
         for annotation in annotations:
             assert annotation["value"]["keypointlabels"][0] == label
 
+            # convert labels
             if label == "White":
                 label = "Matte"
             elif label == "Black":

View file

@@ -5,7 +5,7 @@ import datasets
 dataset_path_train = pathlib.Path("./dataset_illumination/")
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.0"
 
 _DESCRIPTION = ""
@@ -90,9 +90,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
                 meta["gray"]["bounding_box"]["x"] / original_width * width,
                 meta["gray"]["bounding_box"]["y"] / original_height * height,
                 meta["gray"]["bounding_box"]["w"] / original_width * width,
-                meta["gray"]["bounding_box"]["h"] / original_height * height
+                meta["gray"]["bounding_box"]["h"] / original_height * height,
             ),
-            "Matte"
+            "Matte",
         )
@@ -100,9 +100,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
                 meta["chrome"]["bounding_box"]["x"] / original_width * width,
                 meta["chrome"]["bounding_box"]["y"] / original_height * height,
                 meta["chrome"]["bounding_box"]["w"] / original_width * width,
-                meta["chrome"]["bounding_box"]["h"] / original_height * height
+                meta["chrome"]["bounding_box"]["h"] / original_height * height,
             ),
-            "Chrome"
+            "Chrome",
         )
 
         # generate data

View file

@@ -4,7 +4,7 @@ import datasets
 dataset_path = pathlib.Path("./dataset_predict/")
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.2"
 
 _DESCRIPTION = ""
@@ -60,7 +60,6 @@ class SpherePredict(datasets.GeneratorBasedBuilder):
         jpgs = dataset_path.rglob("*.jpg")
 
         for index, jpg in enumerate(jpgs):
-            print(index, jpg, 2)
 
             # generate data

View file

@@ -4,7 +4,7 @@ import datasets
 dataset_path = pathlib.Path("./dataset_render/")
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.0"
 
 _DESCRIPTION = ""
@@ -20,7 +20,6 @@ _NAMES = [
 
 class SphereSynth(datasets.GeneratorBasedBuilder):
     def _info(self):
-
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
@@ -78,8 +77,9 @@ class SphereSynth(datasets.GeneratorBasedBuilder):
                 1 - float(line[1]),
                 float(line[2]),
                 1 - float(line[3]),
-                line[4].strip()
-            ) for line in lines
+                line[4].strip(),
+            )
+            for line in lines
         ]
 
         bboxes = [
View file

@@ -1,11 +1,11 @@
-from datamodule import DETRDataModule
-from module import DETR
-from pytorch_lightning.callbacks import (
+from datamodule import DETRDataModule, FasterRCNNDataModule # noqa: F401
+from lightning.pytorch.callbacks import (
     ModelCheckpoint,
     RichModelSummary,
     RichProgressBar,
 )
-from pytorch_lightning.cli import LightningCLI
+from lightning.pytorch.cli import LightningCLI
+from module import DETR, FasterRCNN # noqa: F401
 
 
 class MyLightningCLI(LightningCLI):
@@ -15,12 +15,11 @@ class MyLightningCLI(LightningCLI):
         """Add arguments to parser."""
         parser.set_defaults(
             {
-                "trainer.multiple_trainloader_mode": "max_size_cycle",
                 "trainer.max_steps": 5000,
                 "trainer.max_epochs": 1,
                 "trainer.accelerator": "gpu",
-                "trainer.devices": "[1]",
-                "trainer.strategy": "dp",
+                "trainer.devices": "[0]",
+                "trainer.strategy": "auto",
                 "trainer.log_every_n_steps": 25,
                 "trainer.val_check_interval": 200,
                 "trainer.num_sanity_val_steps": 10,

View file

@@ -1,6 +1,6 @@
 import torch
+from lightning.pytorch import LightningModule
 from PIL import ImageDraw
-from pytorch_lightning import LightningModule
 from transformers import (
     DetrForObjectDetection,
     get_cosine_with_hard_restarts_schedule_with_warmup,
@@ -33,13 +33,14 @@ class DETR(LightningModule):
         """
         super().__init__()
 
-        # replace COCO classification head with custom head
+        # get DETR model
         self.net = DetrForObjectDetection.from_pretrained(
             "facebook/detr-resnet-50",
             ignore_mismatched_sizes=True,
             num_queries=num_queries,
             num_labels=num_labels,
         )
+        torch.compile(self.net)
 
         # cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
         self.lr = lr
@@ -47,7 +48,6 @@ class DETR(LightningModule):
         self.weight_decay = weight_decay
         self.warmup_steps = warmup_steps
         self.prediction_threshold = prediction_threshold
-
         self.save_hyperparameters()
 
     def forward(self, pixel_values, pixel_mask, **kwargs):
@@ -100,7 +100,7 @@ class DETR(LightningModule):
         """Training step."""
         outputs = self.common_step(batch, batch_idx)
 
-        # logs metrics for each training_step and the average across the epoch
+        # logs metrics for each training_step
         loss = 0
         for dataloader_name, output in outputs.items():
             loss += output["loss"]
@@ -117,7 +117,7 @@ class DETR(LightningModule):
         """Validation step."""
         outputs = self.common_step(batch, batch_idx)
 
-        # logs metrics for each validation_step and the average across the epoch
+        # logs metrics for each validation_step
        loss = 0
         for dataloader_name, output in outputs.items():
             loss += output["loss"]
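One caveat on the added torch.compile line: torch.compile does not compile in place, it returns the optimized module, so as committed the result is discarded (consistent with the commit message flagging things as broken). The usual pattern, assuming PyTorch >= 2.0, inside DETR.__init__:

    # keep the returned OptimizedModule, otherwise compilation has no effect
    self.net = torch.compile(self.net)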

src/module/FasterRCNN.py (new file, 114 lines)
View file

@@ -0,0 +1,114 @@
import torch
import torchvision
from lightning.pytorch import LightningModule
from PIL import ImageDraw
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights, FastRCNNPredictor


def get_model_instance_segmentation(n_classes: int):
    """Returns a Torchvision FasterRCNN model for fine-tuning.

    Args:
        n_classes (int): number of classes the model should predict, background excluded
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
        box_detections_per_img=10,  # cap the number of detections, else OOM
    )

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained head with a new one (+1 for the background class)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes + 1)

    return model
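A quick smoke test of the head swap, as a sketch with a dummy input: in eval mode, torchvision detection models take a list of image tensors and return one dict per image.

    model = get_model_instance_segmentation(n_classes=3)
    model.eval()
    with torch.no_grad():
        detections = model([torch.rand(3, 512, 512)])
    print(detections[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])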
class FasterRCNN(LightningModule):
    """Faster R-CNN PyTorch Lightning Module, encapsulating common PyTorch functions."""

    def __init__(
        self,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
        num_labels: int = 3,
    ):
        """Constructor, build model, save hyperparameters."""
        super().__init__()

        # get Faster R-CNN model
        self.net = get_model_instance_segmentation(num_labels)

        # hyperparameters
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_labels = num_labels
        self.save_hyperparameters()

    def forward(self, imgs, **kwargs):
        """Forward pass."""
        return self.net(imgs, **kwargs)

    def common_step(self, batchs, batch_idx):
        # initialize outputs
        outputs = {}

        # for each dataloader
        for dataloader_name, batch in batchs.items():
            # extract pixel_values and labels from batch
            images = batch["pixel_values"]
            targets = batch["labels"]

            # forward pass
            model_output = self(images, targets=targets)

            # get loss
            outputs[dataloader_name] = {
                "loss": sum(model_output.values()),
                "loss_dict": model_output,
            }

        return outputs
    def training_step(self, batch, batch_idx):
        outputs = self.common_step(batch, batch_idx)

        # logs metrics for each training_step
        loss = 0
        for dataloader_name, output in outputs.items():
            loss += output["loss"]
            self.log(f"train_loss_{dataloader_name}", output["loss"])
            for k, v in output["loss_dict"].items():
                self.log(f"train_loss_{k}_{dataloader_name}", v.item())
        self.log("lr", self.optimizers().param_groups[0]["lr"])

        return loss

    def validation_step(self, batch, batch_idx):
        self.train()
        with torch.no_grad():
            outputs = self.common_step(batch, batch_idx)
        self.eval()

        # logs metrics for each validation_step
        loss = 0
        for dataloader_name, output in outputs.items():
            loss += output["loss"]
            self.log(f"val_loss_{dataloader_name}", output["loss"])
            for k, v in output["loss_dict"].items():
                self.log(f"val_loss_{k}_{dataloader_name}", v.item())

        return loss
    def configure_optimizers(self):
        """PyTorch optimizers and schedulers.

        Returns:
            dictionary for PyTorch Lightning optimizer/scheduler configuration
        """
        optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        return {
            "optimizer": optimizer,
        }
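The self.train() / no_grad dance in validation_step exists because torchvision detection models change their return type with the module mode: losses are only computed in train mode. A sketch of the two behaviours, reusing the model built above:

    images = [torch.rand(3, 512, 512)]
    targets = [{"boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]), "labels": torch.tensor([1])}]

    model.train()
    loss_dict = model(images, targets)  # dict of losses: loss_classifier, loss_box_reg, ...
    model.eval()
    detections = model(images)          # list of {"boxes", "labels", "scores"} per image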

src/module/__init__.py (new file, 2 lines)
View file

@@ -0,0 +1,2 @@
from DETR import DETR
from FasterRCNN import FasterRCNN