From d8bba4481a4146759fe3fcc93328e72ec4c53f45 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Wed, 5 Apr 2023 20:11:01 +0200
Subject: [PATCH] refactor!: move files to python modules

feat: add back (broken) faster rcnn model
---
 env.yml => environment.yml                |   4 +-
 src/{datamodule.py => datamodule/DETR.py} |  14 +-
 src/datamodule/FasterRCNN.py              | 314 ++++++++++++++++++
 src/datamodule/__init__.py                |   2 +
 .../antoine_laurent.py}                    |   5 +-
 .../multi_illumination.py}                 |  10 +-
 .../predict.py}                            |   3 +-
 .../synthetic.py}                          |   8 +-
 src/main.py                               |  13 +-
 src/{module.py => module/DETR.py}         |  10 +-
 src/module/FasterRCNN.py                  | 114 +++++++
 src/module/__init__.py                    |   2 +
 12 files changed, 464 insertions(+), 35 deletions(-)
 rename env.yml => environment.yml (84%)
 rename src/{datamodule.py => datamodule/DETR.py} (93%)
 create mode 100644 src/datamodule/FasterRCNN.py
 create mode 100644 src/datamodule/__init__.py
 rename src/{spheres_antoine_laurent.py => dataset/antoine_laurent.py} (98%)
 rename src/{spheres_illumination.py => dataset/multi_illumination.py} (97%)
 rename src/{spheres_predict.py => dataset/predict.py} (99%)
 rename src/{spheres_synth.py => dataset/synthetic.py} (97%)
 rename src/{module.py => module/DETR.py} (96%)
 create mode 100644 src/module/FasterRCNN.py
 create mode 100644 src/module/__init__.py

diff --git a/env.yml b/environment.yml
similarity index 84%
rename from env.yml
rename to environment.yml
index ad80029..0c8d6b1 100644
--- a/env.yml
+++ b/environment.yml
@@ -17,8 +17,8 @@ dependencies:
   - pytorch
   - torchvision
   - torchaudio
-  - pytorch-cuda=11.8
-  - lightning
+  - pytorch-cuda
+  - lightning # currently broken, install manually with pip
   # deep learning libraries
   - transformers
   - datasets
diff --git a/src/datamodule.py b/src/datamodule/DETR.py
similarity index 93%
rename from src/datamodule.py
rename to src/datamodule/DETR.py
index 5db8b58..4f2489c 100644
--- a/src/datamodule.py
+++ b/src/datamodule/DETR.py
@@ -1,7 +1,7 @@
 import datasets
 import torch
-from pytorch_lightning import LightningDataModule
-from pytorch_lightning.utilities import CombinedLoader
+from lightning.pytorch import LightningDataModule
+from lightning.pytorch.utilities import CombinedLoader
 from torch.utils.data import DataLoader
 from torchvision.transforms import AugMix
 from transformers import DetrFeatureExtractor
@@ -41,9 +41,9 @@ class DETRDataModule(LightningDataModule):
     def prepare_data(self):
         """Download data and prepare for training."""
         # load datasets
-        self.illumination = datasets.load_dataset("src/spheres_illumination.py", split="train")
-        self.render = datasets.load_dataset("src/spheres_synth.py", split="train")
-        self.real = datasets.load_dataset("src/spheres.py", split="train")
+        self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
+        self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
+        self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
 
         # split datasets
         self.illumination = self.illumination.train_test_split(test_size=0.01)
@@ -56,9 +56,7 @@ class DETRDataModule(LightningDataModule):
         print(f"real: {self.real}")
 
         # other datasets
-        self.test_ds = datasets.load_dataset("src/spheres_illumination.py", split="test")
-        # self.predict_ds = datasets.load_dataset("src/spheres.py", split="train").shuffle().select(range(16))
-        self.predict_ds = datasets.load_dataset("src/spheres_predict.py", split="train")
+        self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
 
         # define AugMix transform
         self.mix = AugMix()
diff --git a/src/datamodule/FasterRCNN.py b/src/datamodule/FasterRCNN.py
new file mode 100644
index 0000000..7371ffe
--- /dev/null
+++ b/src/datamodule/FasterRCNN.py
@@ -0,0 +1,314 @@
+import datasets
+import torch
+from lightning.pytorch import LightningDataModule
+from lightning.pytorch.utilities import CombinedLoader
+from torch.utils.data import DataLoader
+from torchvision.transforms import AugMix
+from transformers import DetrFeatureExtractor
+
+
+class FasterRCNNDataModule(LightningDataModule):
+    """PyTorch Lightning data module for Faster R-CNN."""
+
+    def __init__(
+        self,
+        num_workers: int = 8,
+        batch_size: int = 5,
+        prefetch_factor: int = 2,
+        model_name: str = "facebook/detr-resnet-50",
+        persistent_workers: bool = True,
+    ):
+        """Constructor.
+
+        Args:
+            num_workers (int, optional): Number of workers.
+            batch_size (int, optional): Batch size.
+            prefetch_factor (int, optional): Prefetch factor.
+            model_name (str, optional): Model name.
+            persistent_workers (bool, optional): Keep workers alive between epochs.
+        """
+        super().__init__()
+
+        # save params
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+        self.prefetch_factor = prefetch_factor
+        self.persistent_workers = persistent_workers
+
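+        # NOTE: the DETR feature extractor is reused here purely for preprocessing
+        # (resizing, normalization, COCO-style annotation parsing); the model itself
+        # is a torchvision Faster R-CNN, so labels and boxes are remapped to its
+        # conventions in the transforms below.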
+        # get feature extractor
+        self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
+
+    def prepare_data(self):
+        """Download data and prepare for training."""
+        # load datasets
+        self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
+        self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
+        self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
+
+        # split datasets
+        self.illumination = self.illumination.train_test_split(test_size=0.01)
+        self.render = self.render.train_test_split(test_size=0.01)
+        self.real = self.real.train_test_split(test_size=0.1)
+
+        # print some info
+        print(f"illumination: {self.illumination}")
+        print(f"render: {self.render}")
+        print(f"real: {self.real}")
+
+        # other datasets
+        self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
+
+        # define AugMix transform
+        self.mix = AugMix()
+
+        # useful mappings
+        self.labels = self.real["test"].features["objects"][0]["category_id"].names
+        self.id2label = {k: v for k, v in enumerate(self.labels)}
+        self.label2id = {v: k for k, v in enumerate(self.labels)}
+
+    def train_transform(self, batch):
+        """Training transform.
+
+        Args:
+            batch (dict): Batch precollated by HuggingFace datasets.
+                Structure is similar to the following:
+                {
+                    "image": list[PIL.Image],
+                    "image_id": list[int],
+                    "objects": [
+                        {
+                            "bbox": list[float, 4],
+                            "category_id": int,
+                        }
+                    ]
+                }
+
+        Returns:
+            dict: Augmented and processed batch.
+                Structure is similar to the following:
+                {
+                    "pixel_values": TensorType["batch", "channel", "width", "height"],
+                    "pixel_mask": TensorType["batch", 1200, 1200],
+                    "labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],
+                }
+        """
+        # extract images, ids and objects from batch
+        images = batch["image"]
+        ids = batch["image_id"]
+        objects = batch["objects"]
+
+        # apply AugMix transform
+        images_mixed = [self.mix(image) for image in images]
+
+        # build targets for feature extractor
+        targets = [
+            {
+                "image_id": id,
+                "annotations": object,
+            }
+            for id, object in zip(ids, objects)
+        ]
+
+        # process images and targets with feature extractor for DETR
+        processed = self.feature_extractor(
+            images=images_mixed,
+            annotations=targets,
+            return_tensors="pt",
+        )
+
+        for label in processed["labels"]:
+            # rename "class_labels" to "labels",
+            # add 1 since 0 is reserved for background
+            label["labels"] = label["class_labels"] + 1
+            del label["class_labels"]
+
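+            # NOTE: torchvision's Faster R-CNN expects absolute [x1, y1, x2, y2]
+            # boxes, while the DETR feature extractor returns normalized
+            # [xc, yc, w, h]; hence the two conversions below.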
+            # format boxes from [xc, yc, w, h] to [x1, y1, x2, y2]
+            width_height = label["boxes"][:, 2:].clone()
+            label["boxes"][:, :2] = label["boxes"][:, :2] - width_height / 2
+            label["boxes"][:, 2:] = label["boxes"][:, :2] + width_height
+
+            # convert from normalized to absolute coordinates
+            label["boxes"][:, 0] *= label["size"][1]
+            label["boxes"][:, 1] *= label["size"][0]
+            label["boxes"][:, 2] *= label["size"][1]
+            label["boxes"][:, 3] *= label["size"][0]
+
+        return processed
+
+    def val_transform(self, batch):
+        """Validation transform.
+
+        Just like train_transform, but without AugMix.
+        """
+        # extract images, ids and objects from batch
+        images = batch["image"]
+        ids = batch["image_id"]
+        objects = batch["objects"]
+
+        # build targets for feature extractor
+        targets = [
+            {
+                "image_id": id,
+                "annotations": object,
+            }
+            for id, object in zip(ids, objects)
+        ]
+
+        processed = self.feature_extractor(
+            images=images,
+            annotations=targets,
+            return_tensors="pt",
+        )
+
+        for label in processed["labels"]:
+            # rename "class_labels" to "labels",
+            # add 1 since 0 is reserved for background
+            label["labels"] = label["class_labels"] + 1
+            del label["class_labels"]
+
+            # format boxes from [xcenter, ycenter, w, h] to [x1, y1, x2, y2]
+            center_x, center_y, width, height = label["boxes"].unbind(-1)
+            label["boxes"] = torch.stack(
+                # top left x, top left y, bottom right x, bottom right y
+                [
+                    (center_x - 0.5 * width),
+                    (center_y - 0.5 * height),
+                    (center_x + 0.5 * width),
+                    (center_y + 0.5 * height),
+                ],
+                dim=-1,
+            )
+
+            # convert from normalized to absolute coordinates
+            label["boxes"][:, 0] *= label["size"][1]
+            label["boxes"][:, 1] *= label["size"][0]
+            label["boxes"][:, 2] *= label["size"][1]
+            label["boxes"][:, 3] *= label["size"][0]
+
+        return processed
+
+    def predict_transform(self, batch):
+        """Prediction transform.
+
+        Just like val_transform, but keeps the raw images in the batch.
+        """
+        processed = self.val_transform(batch)
+
+        # add images to dict
+        processed["images"] = batch["image"]
+
+        return processed
+
+    def collate_fn(self, examples):
+        """Collate function.
+
+        Convert list of dicts to dict of Tensors.
+        """
+        return {
+            "pixel_values": torch.stack([data["pixel_values"] for data in examples]),
+            "labels": [data["labels"] for data in examples],
+        }
+
+    def collate_fn_predict(self, examples):
+        """Collate function for prediction.
+
+        Convert list of dicts to dict of Tensors, keeping the raw images.
+        """
+        return {
+            "pixel_values": torch.stack([data["pixel_values"] for data in examples]),
+            "labels": [data["labels"] for data in examples],
+            "images": [data["images"] for data in examples],
+        }
+
+    def train_dataloader(self):
+        """Training dataloader."""
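+        # NOTE: each domain keeps its own loader; CombinedLoader in
+        # "max_size_cycle" mode cycles the shorter loaders so that every
+        # training step sees one sub-batch per dataset.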
+ """ + return { + "pixel_values": torch.stack([data["pixel_values"] for data in examples]), + "labels": [data["labels"] for data in examples], + "images": [data["images"] for data in examples], + } + + def train_dataloader(self): + """Training dataloader.""" + loaders = { + "illumination": DataLoader( + self.illumination["train"].with_transform(self.val_transform), + shuffle=True, + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ), + "render": DataLoader( + self.render["train"].with_transform(self.val_transform), + shuffle=True, + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ), + "real": DataLoader( + self.real["train"].with_transform(self.val_transform), + shuffle=True, + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ), + } + + return CombinedLoader(loaders, mode="max_size_cycle") + + def val_dataloader(self): + """Validation dataloader.""" + loaders = { + "illumination": DataLoader( + self.illumination["test"].with_transform(self.val_transform), + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ), + "render": DataLoader( + self.render["test"].with_transform(self.val_transform), + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ), + "real": DataLoader( + self.real["test"].with_transform(self.val_transform), + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ), + } + + return CombinedLoader(loaders, mode="max_size_cycle") + + def predict_dataloader(self): + """Prediction dataloader.""" + return DataLoader( + self.predict_ds.with_transform(self.predict_transform), + pin_memory=True, + persistent_workers=self.persistent_workers, + collate_fn=self.collate_fn_predict, + batch_size=self.batch_size, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + +if __name__ == "__main__": + # load data + dm = FasterRCNNDataModule() + dm.prepare_data() + ds = dm.train_dataloader() + + for batch in ds: + print(batch) diff --git a/src/datamodule/__init__.py b/src/datamodule/__init__.py new file mode 100644 index 0000000..ef389a4 --- /dev/null +++ b/src/datamodule/__init__.py @@ -0,0 +1,2 @@ +from DETR import DETRDataModule +from FasterRCNN import FasterRCNNDataModule diff --git a/src/spheres_antoine_laurent.py b/src/dataset/antoine_laurent.py similarity index 98% rename from src/spheres_antoine_laurent.py rename to src/dataset/antoine_laurent.py index 5ac39d6..5826f8d 100644 --- a/src/spheres_antoine_laurent.py +++ b/src/dataset/antoine_laurent.py @@ -7,9 +7,9 @@ import numpy as np prefix = "/data/local-files/?d=spheres/" dataset_path = pathlib.Path("./dataset_antoine_laurent/") -annotation_path = dataset_path / "annotations.json" +annotation_path = dataset_path / "annotations.json" # from 
+if __name__ == "__main__":
+    # load data
+    dm = FasterRCNNDataModule()
+    dm.prepare_data()
+    ds = dm.train_dataloader()
+
+    for batch in ds:
+        print(batch)
diff --git a/src/datamodule/__init__.py b/src/datamodule/__init__.py
new file mode 100644
index 0000000..ef389a4
--- /dev/null
+++ b/src/datamodule/__init__.py
@@ -0,0 +1,2 @@
+from .DETR import DETRDataModule
+from .FasterRCNN import FasterRCNNDataModule
diff --git a/src/spheres_antoine_laurent.py b/src/dataset/antoine_laurent.py
similarity index 98%
rename from src/spheres_antoine_laurent.py
rename to src/dataset/antoine_laurent.py
index 5ac39d6..5826f8d 100644
--- a/src/spheres_antoine_laurent.py
+++ b/src/dataset/antoine_laurent.py
@@ -7,9 +7,9 @@ import numpy as np
 prefix = "/data/local-files/?d=spheres/"
 
 dataset_path = pathlib.Path("./dataset_antoine_laurent/")
-annotation_path = dataset_path / "annotations.json"
+annotation_path = dataset_path / "annotations.json"  # from labelstudio
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.0"
 
 _DESCRIPTION = ""
@@ -105,6 +105,7 @@ class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
                 for annotation in annotations:
                     assert annotation["value"]["keypointlabels"][0] == label
 
+                # convert labels
                 if label == "White":
                     label = "Matte"
                 elif label == "Black":
diff --git a/src/spheres_illumination.py b/src/dataset/multi_illumination.py
similarity index 97%
rename from src/spheres_illumination.py
rename to src/dataset/multi_illumination.py
index 5c6828c..f700299 100644
--- a/src/spheres_illumination.py
+++ b/src/dataset/multi_illumination.py
@@ -5,7 +5,7 @@ import datasets
 
 dataset_path_train = pathlib.Path("./dataset_illumination/")
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.0"
 
 _DESCRIPTION = ""
@@ -90,9 +90,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
                     meta["gray"]["bounding_box"]["x"] / original_width * width,
                     meta["gray"]["bounding_box"]["y"] / original_height * height,
                     meta["gray"]["bounding_box"]["w"] / original_width * width,
-                    meta["gray"]["bounding_box"]["h"] / original_height * height
+                    meta["gray"]["bounding_box"]["h"] / original_height * height,
                 ),
-                "Matte"
+                "Matte",
             )
 
             chrome = (
@@ -100,9 +100,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
                     meta["chrome"]["bounding_box"]["x"] / original_width * width,
                     meta["chrome"]["bounding_box"]["y"] / original_height * height,
                     meta["chrome"]["bounding_box"]["w"] / original_width * width,
-                    meta["chrome"]["bounding_box"]["h"] / original_height * height
+                    meta["chrome"]["bounding_box"]["h"] / original_height * height,
                 ),
-                "Chrome"
+                "Chrome",
             )
 
             # generate data
diff --git a/src/spheres_predict.py b/src/dataset/predict.py
similarity index 99%
rename from src/spheres_predict.py
rename to src/dataset/predict.py
index 28a2235..1ee5382 100644
--- a/src/spheres_predict.py
+++ b/src/dataset/predict.py
@@ -4,7 +4,7 @@ import datasets
 
 dataset_path = pathlib.Path("./dataset_predict/")
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.2"
 
 _DESCRIPTION = ""
@@ -60,7 +60,6 @@ class SpherePredict(datasets.GeneratorBasedBuilder):
 
         jpgs = dataset_path.rglob("*.jpg")
         for index, jpg in enumerate(jpgs):
-
             print(index, jpg, 2)
 
             # generate data
diff --git a/src/spheres_synth.py b/src/dataset/synthetic.py
similarity index 97%
rename from src/spheres_synth.py
rename to src/dataset/synthetic.py
index ed1402d..e8f17ce 100644
--- a/src/spheres_synth.py
+++ b/src/dataset/synthetic.py
@@ -4,7 +4,7 @@ import datasets
 
 dataset_path = pathlib.Path("./dataset_render/")
 
-_VERSION = "1.0.0"
+_VERSION = "2.0.0"
 
 _DESCRIPTION = ""
@@ -20,7 +20,6 @@ _NAMES = [
 
 
 class SphereSynth(datasets.GeneratorBasedBuilder):
-
     def _info(self):
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
@@ -78,8 +77,9 @@ class SphereSynth(datasets.GeneratorBasedBuilder):
                     1 - float(line[1]),
                     float(line[2]),
                     1 - float(line[3]),
-                    line[4].strip()
-                ) for line in lines
+                    line[4].strip(),
+                )
+                for line in lines
             ]
 
             bboxes = [
diff --git a/src/main.py b/src/main.py
index fd71af4..54218d0 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,11 +1,11 @@
-from datamodule import DETRDataModule
-from module import DETR
-from pytorch_lightning.callbacks import (
+from datamodule import DETRDataModule, FasterRCNNDataModule  # noqa: F401
+from lightning.pytorch.callbacks import (
     ModelCheckpoint,
     RichModelSummary,
     RichProgressBar,
 )
-from pytorch_lightning.cli import LightningCLI
+from lightning.pytorch.cli import LightningCLI
+from module import DETR, FasterRCNN  # noqa: F401
 
 
 class MyLightningCLI(LightningCLI):
@@ -15,12 +15,11 @@ class MyLightningCLI(LightningCLI):
         """Add arguments to parser."""
         parser.set_defaults(
             {
-                "trainer.multiple_trainloader_mode": "max_size_cycle",
                 "trainer.max_steps": 5000,
                 "trainer.max_epochs": 1,
                 "trainer.accelerator": "gpu",
-                "trainer.devices": "[1]",
-                "trainer.strategy": "dp",
+                "trainer.devices": "[0]",
+                "trainer.strategy": "auto",
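+                # NOTE: "dp" and trainer.multiple_trainloader_mode were removed
+                # in Lightning 2.0; the cycling mode now lives on CombinedLoader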
                 "trainer.log_every_n_steps": 25,
                 "trainer.val_check_interval": 200,
                 "trainer.num_sanity_val_steps": 10,
diff --git a/src/module.py b/src/module/DETR.py
similarity index 96%
rename from src/module.py
rename to src/module/DETR.py
index a745c39..5ae421c 100644
--- a/src/module.py
+++ b/src/module/DETR.py
@@ -1,6 +1,6 @@
 import torch
+from lightning.pytorch import LightningModule
 from PIL import ImageDraw
-from pytorch_lightning import LightningModule
 from transformers import (
     DetrForObjectDetection,
     get_cosine_with_hard_restarts_schedule_with_warmup,
@@ -33,13 +33,14 @@ class DETR(LightningModule):
         """
         super().__init__()
 
-        # replace COCO classification head with custom head
+        # get DETR model
         self.net = DetrForObjectDetection.from_pretrained(
             "facebook/detr-resnet-50",
             ignore_mismatched_sizes=True,
             num_queries=num_queries,
             num_labels=num_labels,
         )
+        self.net = torch.compile(self.net)  # torch.compile is not in-place
 
         # cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
         self.lr = lr
@@ -47,7 +48,6 @@ class DETR(LightningModule):
         self.weight_decay = weight_decay
         self.warmup_steps = warmup_steps
         self.prediction_threshold = prediction_threshold
-
         self.save_hyperparameters()
 
     def forward(self, pixel_values, pixel_mask, **kwargs):
@@ -100,7 +100,7 @@ class DETR(LightningModule):
         """Training step."""
         outputs = self.common_step(batch, batch_idx)
 
-        # logs metrics for each training_step and the average across the epoch
+        # logs metrics for each training_step
         loss = 0
         for dataloader_name, output in outputs.items():
             loss += output["loss"]
@@ -117,7 +117,7 @@ class DETR(LightningModule):
         """Validation step."""
         outputs = self.common_step(batch, batch_idx)
 
-        # logs metrics for each validation_step and the average across the epoch
+        # logs metrics for each validation_step
         loss = 0
         for dataloader_name, output in outputs.items():
             loss += output["loss"]
diff --git a/src/module/FasterRCNN.py b/src/module/FasterRCNN.py
new file mode 100644
index 0000000..45adc5b
--- /dev/null
+++ b/src/module/FasterRCNN.py
@@ -0,0 +1,114 @@
+import torch
+import torchvision
+from lightning.pytorch import LightningModule
+from PIL import ImageDraw
+from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights, FastRCNNPredictor
+
+
+def get_model_instance_segmentation(n_classes: int):
+    """Returns a torchvision Faster R-CNN model for finetuning.
+
+    Args:
+        n_classes (int): number of classes the model should predict, background excluded
+    """
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
+        weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
+        box_detections_per_img=10,  # cap number of detections, else oom
+    )
+
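+    # NOTE: replacing the box predictor is the standard torchvision finetuning
+    # recipe; n_classes + 1 accounts for the implicit background class (label 0),
+    # matching the +1 offset applied in the datamodule transforms.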
+    # get number of input features for the classifier
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+    # replace the pre-trained head with a new one
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes + 1)
+
+    return model
+
+
+class FasterRCNN(LightningModule):
+    """Faster R-CNN PyTorch Lightning Module, encapsulating common PyTorch functions."""
+
+    def __init__(
+        self,
+        lr: float = 1e-4,
+        weight_decay: float = 1e-4,
+        num_labels: int = 3,
+    ):
+        """Constructor, build model, save hyperparameters."""
+        super().__init__()
+
+        # get Faster R-CNN model
+        self.net = get_model_instance_segmentation(num_labels)
+
+        # hyperparameters
+        self.lr = lr
+        self.weight_decay = weight_decay
+        self.num_labels = num_labels
+        self.save_hyperparameters()
+
+    def forward(self, imgs, **kwargs):
+        """Forward pass."""
+        return self.net(imgs, **kwargs)
+
+    def common_step(self, batches, batch_idx):
+        # initialize outputs
+        outputs = {}
+
+        # for each dataloader
+        for dataloader_name, batch in batches.items():
+            # extract pixel_values and labels from batch
+            images = batch["pixel_values"]
+            targets = batch["labels"]
+
+            # forward pass, returns a dict of losses in train mode
+            model_output = self(images, targets=targets)
+
+            # get loss
+            outputs[dataloader_name] = {
+                "loss": sum(model_output.values()),
+                "loss_dict": model_output,
+            }
+
+        return outputs
+
+    def training_step(self, batch, batch_idx):
+        outputs = self.common_step(batch, batch_idx)
+
+        # logs metrics for each training_step
+        loss = 0
+        for dataloader_name, output in outputs.items():
+            loss += output["loss"]
+            self.log(f"train_loss_{dataloader_name}", output["loss"])
+            for k, v in output["loss_dict"].items():
+                self.log(f"train_loss_{k}_{dataloader_name}", v.item())
+
+        self.log("lr", self.optimizers().param_groups[0]["lr"])
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        # torchvision detection models only return losses in train mode,
+        # so temporarily switch to train() (under no_grad) to log val losses
+        self.train()
+        with torch.no_grad():
+            outputs = self.common_step(batch, batch_idx)
+        self.eval()
+
+        # logs metrics for each validation_step
+        loss = 0
+        for dataloader_name, output in outputs.items():
+            loss += output["loss"]
+            self.log(f"val_loss_{dataloader_name}", output["loss"])
+            for k, v in output["loss_dict"].items():
+                self.log(f"val_loss_{k}_{dataloader_name}", v.item())
+
+        return loss
+
+    def configure_optimizers(self):
+        """PyTorch optimizers and schedulers.
+
+        Returns:
+            dictionary for PyTorch Lightning optimizer/scheduler configuration
+        """
+        optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+
+        return {
+            "optimizer": optimizer,
+        }
diff --git a/src/module/__init__.py b/src/module/__init__.py
new file mode 100644
index 0000000..cc085db
--- /dev/null
+++ b/src/module/__init__.py
@@ -0,0 +1,2 @@
+from .DETR import DETR
+from .FasterRCNN import FasterRCNN