mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-09-18 19:15:27 +00:00
refactor!: move files to python modules
feat: add back (broken) faster rcnn model
This commit is contained in:
parent
8f24f83b4a
commit
d8bba4481a
|
@ -17,8 +17,8 @@ dependencies:
|
||||||
- pytorch
|
- pytorch
|
||||||
- torchvision
|
- torchvision
|
||||||
- torchaudio
|
- torchaudio
|
||||||
- pytorch-cuda=11.8
|
- pytorch-cuda
|
||||||
- lightning
|
- lightning # currently broken, install manually with pip
|
||||||
# deep learning libraries
|
# deep learning libraries
|
||||||
- transformers
|
- transformers
|
||||||
- datasets
|
- datasets
|
|
@ -1,7 +1,7 @@
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import LightningDataModule
|
from lightning.pytorch import LightningDataModule
|
||||||
from pytorch_lightning.utilities import CombinedLoader
|
from lightning.pytorch.utilities import CombinedLoader
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.transforms import AugMix
|
from torchvision.transforms import AugMix
|
||||||
from transformers import DetrFeatureExtractor
|
from transformers import DetrFeatureExtractor
|
||||||
|
@ -41,9 +41,9 @@ class DETRDataModule(LightningDataModule):
|
||||||
def prepare_data(self):
|
def prepare_data(self):
|
||||||
"""Download data and prepare for training."""
|
"""Download data and prepare for training."""
|
||||||
# load datasets
|
# load datasets
|
||||||
self.illumination = datasets.load_dataset("src/spheres_illumination.py", split="train")
|
self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
|
||||||
self.render = datasets.load_dataset("src/spheres_synth.py", split="train")
|
self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
|
||||||
self.real = datasets.load_dataset("src/spheres.py", split="train")
|
self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
|
||||||
|
|
||||||
# split datasets
|
# split datasets
|
||||||
self.illumination = self.illumination.train_test_split(test_size=0.01)
|
self.illumination = self.illumination.train_test_split(test_size=0.01)
|
||||||
|
@ -56,9 +56,7 @@ class DETRDataModule(LightningDataModule):
|
||||||
print(f"real: {self.real}")
|
print(f"real: {self.real}")
|
||||||
|
|
||||||
# other datasets
|
# other datasets
|
||||||
self.test_ds = datasets.load_dataset("src/spheres_illumination.py", split="test")
|
self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
|
||||||
# self.predict_ds = datasets.load_dataset("src/spheres.py", split="train").shuffle().select(range(16))
|
|
||||||
self.predict_ds = datasets.load_dataset("src/spheres_predict.py", split="train")
|
|
||||||
|
|
||||||
# define AugMix transform
|
# define AugMix transform
|
||||||
self.mix = AugMix()
|
self.mix = AugMix()
|
314
src/datamodule/FasterRCNN.py
Normal file
314
src/datamodule/FasterRCNN.py
Normal file
|
@ -0,0 +1,314 @@
|
||||||
|
import datasets
|
||||||
|
import torch
|
||||||
|
from lightning.pytorch import LightningDataModule
|
||||||
|
from lightning.pytorch.utilities import CombinedLoader
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision.transforms import AugMix
|
||||||
|
from transformers import DetrFeatureExtractor
|
||||||
|
|
||||||
|
|
||||||
|
class FasterRCNNDataModule(LightningDataModule):
|
||||||
|
"""PyTorch Lightning data module for Faster RCNN."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_workers: int = 8,
|
||||||
|
batch_size: int = 5,
|
||||||
|
prefetch_factor: int = 2,
|
||||||
|
model_name: str = "facebook/detr-resnet-50",
|
||||||
|
persistent_workers: bool = True,
|
||||||
|
):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_workers (int, optional): Number of workers.
|
||||||
|
batch_size (int, optional): Batch size.
|
||||||
|
prefetch_factor (int, optional): Prefetch factor.
|
||||||
|
val_split (float, optional): Validation split.
|
||||||
|
model_name (str, optional): Model name.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# save params
|
||||||
|
self.num_workers = num_workers
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.prefetch_factor = prefetch_factor
|
||||||
|
self.persistent_workers = persistent_workers
|
||||||
|
|
||||||
|
# get feature extractor
|
||||||
|
self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
|
||||||
|
|
||||||
|
def prepare_data(self):
|
||||||
|
"""Download data and prepare for training."""
|
||||||
|
# load datasets
|
||||||
|
self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
|
||||||
|
self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
|
||||||
|
self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
|
||||||
|
|
||||||
|
# split datasets
|
||||||
|
self.illumination = self.illumination.train_test_split(test_size=0.01)
|
||||||
|
self.render = self.render.train_test_split(test_size=0.01)
|
||||||
|
self.real = self.real.train_test_split(test_size=0.1)
|
||||||
|
|
||||||
|
# print some info
|
||||||
|
print(f"illumination: {self.illumination}")
|
||||||
|
print(f"render: {self.render}")
|
||||||
|
print(f"real: {self.real}")
|
||||||
|
|
||||||
|
# other datasets
|
||||||
|
self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
|
||||||
|
|
||||||
|
# define AugMix transform
|
||||||
|
self.mix = AugMix()
|
||||||
|
|
||||||
|
# useful mappings
|
||||||
|
self.labels = self.real["test"].features["objects"][0]["category_id"].names
|
||||||
|
self.id2label = {k: v for k, v in enumerate(self.labels)}
|
||||||
|
self.label2id = {v: k for k, v in enumerate(self.labels)}
|
||||||
|
|
||||||
|
def train_transform(self, batch):
|
||||||
|
"""Training transform.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (dict): Batch precollated by HuggingFace datasets.
|
||||||
|
Structure is similar to the following:
|
||||||
|
{
|
||||||
|
"image": list[PIL.Image],
|
||||||
|
"image_id": list[int],
|
||||||
|
"objects": [
|
||||||
|
{
|
||||||
|
"bbox": list[float, 4],
|
||||||
|
"category_id": int,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Augmented and processed batch.
|
||||||
|
Structure is similar to the following:
|
||||||
|
{
|
||||||
|
"pixel_values": TensorType["batch", "canal", "width", "height"],
|
||||||
|
"pixel_mask": TensorType["batch", 1200, 1200],
|
||||||
|
"labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# extract images, ids and objects from batch
|
||||||
|
images = batch["image"]
|
||||||
|
ids = batch["image_id"]
|
||||||
|
objects = batch["objects"]
|
||||||
|
|
||||||
|
# apply AugMix transform
|
||||||
|
images_mixed = [self.mix(image) for image in images]
|
||||||
|
|
||||||
|
# build targets for feature extractor
|
||||||
|
targets = [
|
||||||
|
{
|
||||||
|
"image_id": id,
|
||||||
|
"annotations": object,
|
||||||
|
}
|
||||||
|
for id, object in zip(ids, objects)
|
||||||
|
]
|
||||||
|
|
||||||
|
# process images and targets with feature extractor for DETR
|
||||||
|
processed = self.feature_extractor(
|
||||||
|
images=images_mixed,
|
||||||
|
annotations=targets,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
for label in processed["labels"]:
|
||||||
|
# renamed "class_labels" to "labels"
|
||||||
|
# add 1 since 0 is reserved for background
|
||||||
|
label["labels"] = label["class_labels"] + 1
|
||||||
|
del label["class_labels"]
|
||||||
|
|
||||||
|
# format boxes from [xc, yc, w, h] to [x1, y1, x2, y2]
|
||||||
|
width_height = label["boxes"][:, 2:]
|
||||||
|
label["boxes"][:, :2] = label["boxes"][:, :2] - width_height / 2
|
||||||
|
label["boxes"][:, 2:] = label["boxes"][:, :2] + width_height / 2
|
||||||
|
|
||||||
|
# convert from normalized to absolute coordinates
|
||||||
|
label["boxes"][:, 0] *= label["size"][1]
|
||||||
|
label["boxes"][:, 1] *= label["size"][0]
|
||||||
|
label["boxes"][:, 2] *= label["size"][1]
|
||||||
|
label["boxes"][:, 3] *= label["size"][0]
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def val_transform(self, batch):
|
||||||
|
"""Validation transform.
|
||||||
|
|
||||||
|
Just like Training transform, but without AugMix.
|
||||||
|
"""
|
||||||
|
# extract images, ids and objects from batch
|
||||||
|
images = batch["image"]
|
||||||
|
ids = batch["image_id"]
|
||||||
|
objects = batch["objects"]
|
||||||
|
|
||||||
|
# build targets for feature extractor
|
||||||
|
targets = [
|
||||||
|
{
|
||||||
|
"image_id": id,
|
||||||
|
"annotations": object,
|
||||||
|
}
|
||||||
|
for id, object in zip(ids, objects)
|
||||||
|
]
|
||||||
|
|
||||||
|
processed = self.feature_extractor(
|
||||||
|
images=images,
|
||||||
|
annotations=targets,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
for label in processed["labels"]:
|
||||||
|
# renamed "class_labels" to "labels"
|
||||||
|
# add 1 since 0 is reserved for background
|
||||||
|
label["labels"] = label["class_labels"] + 1
|
||||||
|
del label["class_labels"]
|
||||||
|
|
||||||
|
# format boxes from [xcenter, ycenter, w, h] to [x1, y1, x2, y2]
|
||||||
|
center_x, center_y, width, height = label["boxes"].unbind(-1)
|
||||||
|
label["boxes"] = torch.stack(
|
||||||
|
# top left x, top left y, bottom right x, bottom right y
|
||||||
|
[
|
||||||
|
(center_x - 0.5 * width),
|
||||||
|
(center_y - 0.5 * height),
|
||||||
|
(center_x + 0.5 * width),
|
||||||
|
(center_y + 0.5 * height),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert from normalized to absolute coordinates
|
||||||
|
label["boxes"][:, 0] *= label["size"][1]
|
||||||
|
label["boxes"][:, 1] *= label["size"][0]
|
||||||
|
label["boxes"][:, 2] *= label["size"][1]
|
||||||
|
label["boxes"][:, 3] *= label["size"][0]
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def predict_transform(self, batch):
|
||||||
|
"""Prediction transform.
|
||||||
|
|
||||||
|
Just like val_transform, but with images.
|
||||||
|
"""
|
||||||
|
processed = self.val_transform(batch)
|
||||||
|
|
||||||
|
# add images to dict
|
||||||
|
processed["images"] = batch["image"]
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def collate_fn(self, examples):
|
||||||
|
"""Collate function.
|
||||||
|
|
||||||
|
Convert list of dicts to dict of Tensors.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"pixel_values": torch.stack([data["pixel_values"] for data in examples]),
|
||||||
|
"labels": [data["labels"] for data in examples],
|
||||||
|
}
|
||||||
|
|
||||||
|
def collate_fn_predict(self, examples):
|
||||||
|
"""Collate function.
|
||||||
|
|
||||||
|
Convert list of dicts to dict of Tensors.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"pixel_values": torch.stack([data["pixel_values"] for data in examples]),
|
||||||
|
"labels": [data["labels"] for data in examples],
|
||||||
|
"images": [data["images"] for data in examples],
|
||||||
|
}
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
"""Training dataloader."""
|
||||||
|
loaders = {
|
||||||
|
"illumination": DataLoader(
|
||||||
|
self.illumination["train"].with_transform(self.val_transform),
|
||||||
|
shuffle=True,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
),
|
||||||
|
"render": DataLoader(
|
||||||
|
self.render["train"].with_transform(self.val_transform),
|
||||||
|
shuffle=True,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
),
|
||||||
|
"real": DataLoader(
|
||||||
|
self.real["train"].with_transform(self.val_transform),
|
||||||
|
shuffle=True,
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
return CombinedLoader(loaders, mode="max_size_cycle")
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
"""Validation dataloader."""
|
||||||
|
loaders = {
|
||||||
|
"illumination": DataLoader(
|
||||||
|
self.illumination["test"].with_transform(self.val_transform),
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
),
|
||||||
|
"render": DataLoader(
|
||||||
|
self.render["test"].with_transform(self.val_transform),
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
),
|
||||||
|
"real": DataLoader(
|
||||||
|
self.real["test"].with_transform(self.val_transform),
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
return CombinedLoader(loaders, mode="max_size_cycle")
|
||||||
|
|
||||||
|
def predict_dataloader(self):
|
||||||
|
"""Prediction dataloader."""
|
||||||
|
return DataLoader(
|
||||||
|
self.predict_ds.with_transform(self.predict_transform),
|
||||||
|
pin_memory=True,
|
||||||
|
persistent_workers=self.persistent_workers,
|
||||||
|
collate_fn=self.collate_fn_predict,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
prefetch_factor=self.prefetch_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# load data
|
||||||
|
dm = FasterRCNNDataModule()
|
||||||
|
dm.prepare_data()
|
||||||
|
ds = dm.train_dataloader()
|
||||||
|
|
||||||
|
for batch in ds:
|
||||||
|
print(batch)
|
2
src/datamodule/__init__.py
Normal file
2
src/datamodule/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from DETR import DETRDataModule
|
||||||
|
from FasterRCNN import FasterRCNNDataModule
|
|
@ -7,9 +7,9 @@ import numpy as np
|
||||||
|
|
||||||
prefix = "/data/local-files/?d=spheres/"
|
prefix = "/data/local-files/?d=spheres/"
|
||||||
dataset_path = pathlib.Path("./dataset_antoine_laurent/")
|
dataset_path = pathlib.Path("./dataset_antoine_laurent/")
|
||||||
annotation_path = dataset_path / "annotations.json"
|
annotation_path = dataset_path / "annotations.json" # from labelstudio
|
||||||
|
|
||||||
_VERSION = "1.0.0"
|
_VERSION = "2.0.0"
|
||||||
|
|
||||||
_DESCRIPTION = ""
|
_DESCRIPTION = ""
|
||||||
|
|
||||||
|
@ -105,6 +105,7 @@ class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
|
||||||
for annotation in annotations:
|
for annotation in annotations:
|
||||||
assert annotation["value"]["keypointlabels"][0] == label
|
assert annotation["value"]["keypointlabels"][0] == label
|
||||||
|
|
||||||
|
# convert labels
|
||||||
if label == "White":
|
if label == "White":
|
||||||
label = "Matte"
|
label = "Matte"
|
||||||
elif label == "Black":
|
elif label == "Black":
|
|
@ -5,7 +5,7 @@ import datasets
|
||||||
|
|
||||||
dataset_path_train = pathlib.Path("./dataset_illumination/")
|
dataset_path_train = pathlib.Path("./dataset_illumination/")
|
||||||
|
|
||||||
_VERSION = "1.0.0"
|
_VERSION = "2.0.0"
|
||||||
|
|
||||||
_DESCRIPTION = ""
|
_DESCRIPTION = ""
|
||||||
|
|
||||||
|
@ -90,9 +90,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
|
||||||
meta["gray"]["bounding_box"]["x"] / original_width * width,
|
meta["gray"]["bounding_box"]["x"] / original_width * width,
|
||||||
meta["gray"]["bounding_box"]["y"] / original_height * height,
|
meta["gray"]["bounding_box"]["y"] / original_height * height,
|
||||||
meta["gray"]["bounding_box"]["w"] / original_width * width,
|
meta["gray"]["bounding_box"]["w"] / original_width * width,
|
||||||
meta["gray"]["bounding_box"]["h"] / original_height * height
|
meta["gray"]["bounding_box"]["h"] / original_height * height,
|
||||||
),
|
),
|
||||||
"Matte"
|
"Matte",
|
||||||
)
|
)
|
||||||
|
|
||||||
chrome = (
|
chrome = (
|
||||||
|
@ -100,9 +100,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
|
||||||
meta["chrome"]["bounding_box"]["x"] / original_width * width,
|
meta["chrome"]["bounding_box"]["x"] / original_width * width,
|
||||||
meta["chrome"]["bounding_box"]["y"] / original_height * height,
|
meta["chrome"]["bounding_box"]["y"] / original_height * height,
|
||||||
meta["chrome"]["bounding_box"]["w"] / original_width * width,
|
meta["chrome"]["bounding_box"]["w"] / original_width * width,
|
||||||
meta["chrome"]["bounding_box"]["h"] / original_height * height
|
meta["chrome"]["bounding_box"]["h"] / original_height * height,
|
||||||
),
|
),
|
||||||
"Chrome"
|
"Chrome",
|
||||||
)
|
)
|
||||||
|
|
||||||
# generate data
|
# generate data
|
|
@ -4,7 +4,7 @@ import datasets
|
||||||
|
|
||||||
dataset_path = pathlib.Path("./dataset_predict/")
|
dataset_path = pathlib.Path("./dataset_predict/")
|
||||||
|
|
||||||
_VERSION = "1.0.0"
|
_VERSION = "2.0.2"
|
||||||
|
|
||||||
_DESCRIPTION = ""
|
_DESCRIPTION = ""
|
||||||
|
|
||||||
|
@ -60,7 +60,6 @@ class SpherePredict(datasets.GeneratorBasedBuilder):
|
||||||
jpgs = dataset_path.rglob("*.jpg")
|
jpgs = dataset_path.rglob("*.jpg")
|
||||||
|
|
||||||
for index, jpg in enumerate(jpgs):
|
for index, jpg in enumerate(jpgs):
|
||||||
|
|
||||||
print(index, jpg, 2)
|
print(index, jpg, 2)
|
||||||
|
|
||||||
# generate data
|
# generate data
|
|
@ -4,7 +4,7 @@ import datasets
|
||||||
|
|
||||||
dataset_path = pathlib.Path("./dataset_render/")
|
dataset_path = pathlib.Path("./dataset_render/")
|
||||||
|
|
||||||
_VERSION = "1.0.0"
|
_VERSION = "2.0.0"
|
||||||
|
|
||||||
_DESCRIPTION = ""
|
_DESCRIPTION = ""
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ _NAMES = [
|
||||||
|
|
||||||
|
|
||||||
class SphereSynth(datasets.GeneratorBasedBuilder):
|
class SphereSynth(datasets.GeneratorBasedBuilder):
|
||||||
|
|
||||||
def _info(self):
|
def _info(self):
|
||||||
return datasets.DatasetInfo(
|
return datasets.DatasetInfo(
|
||||||
description=_DESCRIPTION,
|
description=_DESCRIPTION,
|
||||||
|
@ -78,8 +77,9 @@ class SphereSynth(datasets.GeneratorBasedBuilder):
|
||||||
1 - float(line[1]),
|
1 - float(line[1]),
|
||||||
float(line[2]),
|
float(line[2]),
|
||||||
1 - float(line[3]),
|
1 - float(line[3]),
|
||||||
line[4].strip()
|
line[4].strip(),
|
||||||
) for line in lines
|
)
|
||||||
|
for line in lines
|
||||||
]
|
]
|
||||||
|
|
||||||
bboxes = [
|
bboxes = [
|
13
src/main.py
13
src/main.py
|
@ -1,11 +1,11 @@
|
||||||
from datamodule import DETRDataModule
|
from datamodule import DETRDataModule, FasterRCNNDataModule # noqa: F401
|
||||||
from module import DETR
|
from lightning.pytorch.callbacks import (
|
||||||
from pytorch_lightning.callbacks import (
|
|
||||||
ModelCheckpoint,
|
ModelCheckpoint,
|
||||||
RichModelSummary,
|
RichModelSummary,
|
||||||
RichProgressBar,
|
RichProgressBar,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.cli import LightningCLI
|
from lightning.pytorch.cli import LightningCLI
|
||||||
|
from module import DETR, FasterRCNN # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class MyLightningCLI(LightningCLI):
|
class MyLightningCLI(LightningCLI):
|
||||||
|
@ -15,12 +15,11 @@ class MyLightningCLI(LightningCLI):
|
||||||
"""Add arguments to parser."""
|
"""Add arguments to parser."""
|
||||||
parser.set_defaults(
|
parser.set_defaults(
|
||||||
{
|
{
|
||||||
"trainer.multiple_trainloader_mode": "max_size_cycle",
|
|
||||||
"trainer.max_steps": 5000,
|
"trainer.max_steps": 5000,
|
||||||
"trainer.max_epochs": 1,
|
"trainer.max_epochs": 1,
|
||||||
"trainer.accelerator": "gpu",
|
"trainer.accelerator": "gpu",
|
||||||
"trainer.devices": "[1]",
|
"trainer.devices": "[0]",
|
||||||
"trainer.strategy": "dp",
|
"trainer.strategy": "auto",
|
||||||
"trainer.log_every_n_steps": 25,
|
"trainer.log_every_n_steps": 25,
|
||||||
"trainer.val_check_interval": 200,
|
"trainer.val_check_interval": 200,
|
||||||
"trainer.num_sanity_val_steps": 10,
|
"trainer.num_sanity_val_steps": 10,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
from lightning.pytorch import LightningModule
|
||||||
from PIL import ImageDraw
|
from PIL import ImageDraw
|
||||||
from pytorch_lightning import LightningModule
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
DetrForObjectDetection,
|
DetrForObjectDetection,
|
||||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
|
@ -33,13 +33,14 @@ class DETR(LightningModule):
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# replace COCO classification head with custom head
|
# get DETR model
|
||||||
self.net = DetrForObjectDetection.from_pretrained(
|
self.net = DetrForObjectDetection.from_pretrained(
|
||||||
"facebook/detr-resnet-50",
|
"facebook/detr-resnet-50",
|
||||||
ignore_mismatched_sizes=True,
|
ignore_mismatched_sizes=True,
|
||||||
num_queries=num_queries,
|
num_queries=num_queries,
|
||||||
num_labels=num_labels,
|
num_labels=num_labels,
|
||||||
)
|
)
|
||||||
|
torch.compile(self.net)
|
||||||
|
|
||||||
# cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
|
# cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
|
||||||
self.lr = lr
|
self.lr = lr
|
||||||
|
@ -47,7 +48,6 @@ class DETR(LightningModule):
|
||||||
self.weight_decay = weight_decay
|
self.weight_decay = weight_decay
|
||||||
self.warmup_steps = warmup_steps
|
self.warmup_steps = warmup_steps
|
||||||
self.prediction_threshold = prediction_threshold
|
self.prediction_threshold = prediction_threshold
|
||||||
|
|
||||||
self.save_hyperparameters()
|
self.save_hyperparameters()
|
||||||
|
|
||||||
def forward(self, pixel_values, pixel_mask, **kwargs):
|
def forward(self, pixel_values, pixel_mask, **kwargs):
|
||||||
|
@ -100,7 +100,7 @@ class DETR(LightningModule):
|
||||||
"""Training step."""
|
"""Training step."""
|
||||||
outputs = self.common_step(batch, batch_idx)
|
outputs = self.common_step(batch, batch_idx)
|
||||||
|
|
||||||
# logs metrics for each training_step and the average across the epoch
|
# logs metrics for each training_step
|
||||||
loss = 0
|
loss = 0
|
||||||
for dataloader_name, output in outputs.items():
|
for dataloader_name, output in outputs.items():
|
||||||
loss += output["loss"]
|
loss += output["loss"]
|
||||||
|
@ -117,7 +117,7 @@ class DETR(LightningModule):
|
||||||
"""Validation step."""
|
"""Validation step."""
|
||||||
outputs = self.common_step(batch, batch_idx)
|
outputs = self.common_step(batch, batch_idx)
|
||||||
|
|
||||||
# logs metrics for each validation_step and the average across the epoch
|
# logs metrics for each validation_step
|
||||||
loss = 0
|
loss = 0
|
||||||
for dataloader_name, output in outputs.items():
|
for dataloader_name, output in outputs.items():
|
||||||
loss += output["loss"]
|
loss += output["loss"]
|
114
src/module/FasterRCNN.py
Normal file
114
src/module/FasterRCNN.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from lightning.pytorch import LightningModule
|
||||||
|
from PIL import ImageDraw
|
||||||
|
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights, FastRCNNPredictor
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_instance_segmentation(n_classes: int):
|
||||||
|
"""Returns a Torchvision FasterRCNN model for finetunning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes (int): number of classes the model should predict, background excluded
|
||||||
|
"""
|
||||||
|
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
|
||||||
|
weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
|
||||||
|
box_detections_per_img=10, # cap numbers of detections, else oom
|
||||||
|
)
|
||||||
|
|
||||||
|
# get number of input features for the classifier
|
||||||
|
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||||
|
# replace the pre-trained head with a new one
|
||||||
|
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes + 1)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class FasterRCNN(LightningModule):
|
||||||
|
"""Faster R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
weight_decay: float = 1e-4,
|
||||||
|
num_labels: int = 3,
|
||||||
|
):
|
||||||
|
"""Constructor, build model, save hyperparameters."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# get Mask R-CNN model
|
||||||
|
self.net = get_model_instance_segmentation(num_labels)
|
||||||
|
|
||||||
|
# hyperparameters
|
||||||
|
self.lr = lr
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.save_hyperparameters()
|
||||||
|
|
||||||
|
def forward(self, imgs, **kwargs):
|
||||||
|
"""Forward pass."""
|
||||||
|
return self.net(imgs, **kwargs)
|
||||||
|
|
||||||
|
def common_step(self, batchs, batch_idx):
|
||||||
|
# intialize outputs
|
||||||
|
outputs = {}
|
||||||
|
|
||||||
|
# for each dataloader
|
||||||
|
for dataloader_name, batch in batchs.items():
|
||||||
|
# extract pixel_values and labels from batch
|
||||||
|
images = batch["pixel_values"]
|
||||||
|
targets = batch["labels"]
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
model_output = self(images, targets=targets)
|
||||||
|
|
||||||
|
# get loss
|
||||||
|
outputs[dataloader_name] = {
|
||||||
|
"loss": sum(model_output.values()),
|
||||||
|
"loss_dict": model_output,
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
outputs = self.common_step(batch, batch_idx)
|
||||||
|
|
||||||
|
# logs metrics for each training_step
|
||||||
|
loss = 0
|
||||||
|
for dataloader_name, output in outputs.items():
|
||||||
|
loss += output["loss"]
|
||||||
|
self.log(f"train_loss_{dataloader_name}", output["loss"])
|
||||||
|
for k, v in output["loss_dict"].items():
|
||||||
|
self.log(f"train_loss_{k}_{dataloader_name}", v.item())
|
||||||
|
|
||||||
|
self.log("lr", self.optimizers().param_groups[0]["lr"])
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
self.train()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.common_step(batch, batch_idx)
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
# logs metrics for each validation_step
|
||||||
|
loss = 0
|
||||||
|
for dataloader_name, output in outputs.items():
|
||||||
|
loss += output["loss"]
|
||||||
|
self.log(f"val_loss_{dataloader_name}", output["loss"])
|
||||||
|
for k, v in output["loss_dict"].items():
|
||||||
|
self.log(f"val_loss_{k}_{dataloader_name}", v.item())
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"""PyTorch optimizers and Schedulers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionnary for PyTorch Lightning optimizer/scheduler configuration
|
||||||
|
"""
|
||||||
|
optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"optimizer": optimizer,
|
||||||
|
}
|
2
src/module/__init__.py
Normal file
2
src/module/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from DETR import DETR
|
||||||
|
from FasterRCNN import FasterRCNN
|
Loading…
Reference in a new issue