mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-09-16 18:15:28 +00:00
refactor!: move files to python modules
feat: add back (broken) faster rcnn model
This commit is contained in:
parent
8f24f83b4a
commit
d8bba4481a
|
@ -17,8 +17,8 @@ dependencies:
|
|||
- pytorch
|
||||
- torchvision
|
||||
- torchaudio
|
||||
- pytorch-cuda=11.8
|
||||
- lightning
|
||||
- pytorch-cuda
|
||||
- lightning # currently broken, install manually with pip
|
||||
# deep learning libraries
|
||||
- transformers
|
||||
- datasets
|
|
@ -1,7 +1,7 @@
|
|||
import datasets
|
||||
import torch
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from pytorch_lightning.utilities import CombinedLoader
|
||||
from lightning.pytorch import LightningDataModule
|
||||
from lightning.pytorch.utilities import CombinedLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import AugMix
|
||||
from transformers import DetrFeatureExtractor
|
||||
|
@ -41,9 +41,9 @@ class DETRDataModule(LightningDataModule):
|
|||
def prepare_data(self):
|
||||
"""Download data and prepare for training."""
|
||||
# load datasets
|
||||
self.illumination = datasets.load_dataset("src/spheres_illumination.py", split="train")
|
||||
self.render = datasets.load_dataset("src/spheres_synth.py", split="train")
|
||||
self.real = datasets.load_dataset("src/spheres.py", split="train")
|
||||
self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
|
||||
self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
|
||||
self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
|
||||
|
||||
# split datasets
|
||||
self.illumination = self.illumination.train_test_split(test_size=0.01)
|
||||
|
@ -56,9 +56,7 @@ class DETRDataModule(LightningDataModule):
|
|||
print(f"real: {self.real}")
|
||||
|
||||
# other datasets
|
||||
self.test_ds = datasets.load_dataset("src/spheres_illumination.py", split="test")
|
||||
# self.predict_ds = datasets.load_dataset("src/spheres.py", split="train").shuffle().select(range(16))
|
||||
self.predict_ds = datasets.load_dataset("src/spheres_predict.py", split="train")
|
||||
self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
|
||||
|
||||
# define AugMix transform
|
||||
self.mix = AugMix()
|
314
src/datamodule/FasterRCNN.py
Normal file
314
src/datamodule/FasterRCNN.py
Normal file
|
@ -0,0 +1,314 @@
|
|||
import datasets
|
||||
import torch
|
||||
from lightning.pytorch import LightningDataModule
|
||||
from lightning.pytorch.utilities import CombinedLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import AugMix
|
||||
from transformers import DetrFeatureExtractor
|
||||
|
||||
|
||||
class FasterRCNNDataModule(LightningDataModule):
|
||||
"""PyTorch Lightning data module for Faster RCNN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_workers: int = 8,
|
||||
batch_size: int = 5,
|
||||
prefetch_factor: int = 2,
|
||||
model_name: str = "facebook/detr-resnet-50",
|
||||
persistent_workers: bool = True,
|
||||
):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
num_workers (int, optional): Number of workers.
|
||||
batch_size (int, optional): Batch size.
|
||||
prefetch_factor (int, optional): Prefetch factor.
|
||||
val_split (float, optional): Validation split.
|
||||
model_name (str, optional): Model name.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# save params
|
||||
self.num_workers = num_workers
|
||||
self.batch_size = batch_size
|
||||
self.prefetch_factor = prefetch_factor
|
||||
self.persistent_workers = persistent_workers
|
||||
|
||||
# get feature extractor
|
||||
self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
|
||||
|
||||
def prepare_data(self):
|
||||
"""Download data and prepare for training."""
|
||||
# load datasets
|
||||
self.illumination = datasets.load_dataset("src/dataset/multi_illumination.py", split="train")
|
||||
self.render = datasets.load_dataset("src/dataset/synthetic.py", split="train")
|
||||
self.real = datasets.load_dataset("src/dataset/antoine_laurent.py", split="train")
|
||||
|
||||
# split datasets
|
||||
self.illumination = self.illumination.train_test_split(test_size=0.01)
|
||||
self.render = self.render.train_test_split(test_size=0.01)
|
||||
self.real = self.real.train_test_split(test_size=0.1)
|
||||
|
||||
# print some info
|
||||
print(f"illumination: {self.illumination}")
|
||||
print(f"render: {self.render}")
|
||||
print(f"real: {self.real}")
|
||||
|
||||
# other datasets
|
||||
self.predict_ds = datasets.load_dataset("src/dataset/predict.py", split="train")
|
||||
|
||||
# define AugMix transform
|
||||
self.mix = AugMix()
|
||||
|
||||
# useful mappings
|
||||
self.labels = self.real["test"].features["objects"][0]["category_id"].names
|
||||
self.id2label = {k: v for k, v in enumerate(self.labels)}
|
||||
self.label2id = {v: k for k, v in enumerate(self.labels)}
|
||||
|
||||
def train_transform(self, batch):
|
||||
"""Training transform.
|
||||
|
||||
Args:
|
||||
batch (dict): Batch precollated by HuggingFace datasets.
|
||||
Structure is similar to the following:
|
||||
{
|
||||
"image": list[PIL.Image],
|
||||
"image_id": list[int],
|
||||
"objects": [
|
||||
{
|
||||
"bbox": list[float, 4],
|
||||
"category_id": int,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Returns:
|
||||
dict: Augmented and processed batch.
|
||||
Structure is similar to the following:
|
||||
{
|
||||
"pixel_values": TensorType["batch", "canal", "width", "height"],
|
||||
"pixel_mask": TensorType["batch", 1200, 1200],
|
||||
"labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],
|
||||
}
|
||||
"""
|
||||
# extract images, ids and objects from batch
|
||||
images = batch["image"]
|
||||
ids = batch["image_id"]
|
||||
objects = batch["objects"]
|
||||
|
||||
# apply AugMix transform
|
||||
images_mixed = [self.mix(image) for image in images]
|
||||
|
||||
# build targets for feature extractor
|
||||
targets = [
|
||||
{
|
||||
"image_id": id,
|
||||
"annotations": object,
|
||||
}
|
||||
for id, object in zip(ids, objects)
|
||||
]
|
||||
|
||||
# process images and targets with feature extractor for DETR
|
||||
processed = self.feature_extractor(
|
||||
images=images_mixed,
|
||||
annotations=targets,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
for label in processed["labels"]:
|
||||
# renamed "class_labels" to "labels"
|
||||
# add 1 since 0 is reserved for background
|
||||
label["labels"] = label["class_labels"] + 1
|
||||
del label["class_labels"]
|
||||
|
||||
# format boxes from [xc, yc, w, h] to [x1, y1, x2, y2]
|
||||
width_height = label["boxes"][:, 2:]
|
||||
label["boxes"][:, :2] = label["boxes"][:, :2] - width_height / 2
|
||||
label["boxes"][:, 2:] = label["boxes"][:, :2] + width_height / 2
|
||||
|
||||
# convert from normalized to absolute coordinates
|
||||
label["boxes"][:, 0] *= label["size"][1]
|
||||
label["boxes"][:, 1] *= label["size"][0]
|
||||
label["boxes"][:, 2] *= label["size"][1]
|
||||
label["boxes"][:, 3] *= label["size"][0]
|
||||
|
||||
return processed
|
||||
|
||||
def val_transform(self, batch):
|
||||
"""Validation transform.
|
||||
|
||||
Just like Training transform, but without AugMix.
|
||||
"""
|
||||
# extract images, ids and objects from batch
|
||||
images = batch["image"]
|
||||
ids = batch["image_id"]
|
||||
objects = batch["objects"]
|
||||
|
||||
# build targets for feature extractor
|
||||
targets = [
|
||||
{
|
||||
"image_id": id,
|
||||
"annotations": object,
|
||||
}
|
||||
for id, object in zip(ids, objects)
|
||||
]
|
||||
|
||||
processed = self.feature_extractor(
|
||||
images=images,
|
||||
annotations=targets,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
for label in processed["labels"]:
|
||||
# renamed "class_labels" to "labels"
|
||||
# add 1 since 0 is reserved for background
|
||||
label["labels"] = label["class_labels"] + 1
|
||||
del label["class_labels"]
|
||||
|
||||
# format boxes from [xcenter, ycenter, w, h] to [x1, y1, x2, y2]
|
||||
center_x, center_y, width, height = label["boxes"].unbind(-1)
|
||||
label["boxes"] = torch.stack(
|
||||
# top left x, top left y, bottom right x, bottom right y
|
||||
[
|
||||
(center_x - 0.5 * width),
|
||||
(center_y - 0.5 * height),
|
||||
(center_x + 0.5 * width),
|
||||
(center_y + 0.5 * height),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# convert from normalized to absolute coordinates
|
||||
label["boxes"][:, 0] *= label["size"][1]
|
||||
label["boxes"][:, 1] *= label["size"][0]
|
||||
label["boxes"][:, 2] *= label["size"][1]
|
||||
label["boxes"][:, 3] *= label["size"][0]
|
||||
|
||||
return processed
|
||||
|
||||
def predict_transform(self, batch):
|
||||
"""Prediction transform.
|
||||
|
||||
Just like val_transform, but with images.
|
||||
"""
|
||||
processed = self.val_transform(batch)
|
||||
|
||||
# add images to dict
|
||||
processed["images"] = batch["image"]
|
||||
|
||||
return processed
|
||||
|
||||
def collate_fn(self, examples):
|
||||
"""Collate function.
|
||||
|
||||
Convert list of dicts to dict of Tensors.
|
||||
"""
|
||||
return {
|
||||
"pixel_values": torch.stack([data["pixel_values"] for data in examples]),
|
||||
"labels": [data["labels"] for data in examples],
|
||||
}
|
||||
|
||||
def collate_fn_predict(self, examples):
|
||||
"""Collate function.
|
||||
|
||||
Convert list of dicts to dict of Tensors.
|
||||
"""
|
||||
return {
|
||||
"pixel_values": torch.stack([data["pixel_values"] for data in examples]),
|
||||
"labels": [data["labels"] for data in examples],
|
||||
"images": [data["images"] for data in examples],
|
||||
}
|
||||
|
||||
def train_dataloader(self):
|
||||
"""Training dataloader."""
|
||||
loaders = {
|
||||
"illumination": DataLoader(
|
||||
self.illumination["train"].with_transform(self.val_transform),
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
),
|
||||
"render": DataLoader(
|
||||
self.render["train"].with_transform(self.val_transform),
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
),
|
||||
"real": DataLoader(
|
||||
self.real["train"].with_transform(self.val_transform),
|
||||
shuffle=True,
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
),
|
||||
}
|
||||
|
||||
return CombinedLoader(loaders, mode="max_size_cycle")
|
||||
|
||||
def val_dataloader(self):
|
||||
"""Validation dataloader."""
|
||||
loaders = {
|
||||
"illumination": DataLoader(
|
||||
self.illumination["test"].with_transform(self.val_transform),
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
),
|
||||
"render": DataLoader(
|
||||
self.render["test"].with_transform(self.val_transform),
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
),
|
||||
"real": DataLoader(
|
||||
self.real["test"].with_transform(self.val_transform),
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
),
|
||||
}
|
||||
|
||||
return CombinedLoader(loaders, mode="max_size_cycle")
|
||||
|
||||
def predict_dataloader(self):
|
||||
"""Prediction dataloader."""
|
||||
return DataLoader(
|
||||
self.predict_ds.with_transform(self.predict_transform),
|
||||
pin_memory=True,
|
||||
persistent_workers=self.persistent_workers,
|
||||
collate_fn=self.collate_fn_predict,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
prefetch_factor=self.prefetch_factor,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load data
|
||||
dm = FasterRCNNDataModule()
|
||||
dm.prepare_data()
|
||||
ds = dm.train_dataloader()
|
||||
|
||||
for batch in ds:
|
||||
print(batch)
|
2
src/datamodule/__init__.py
Normal file
2
src/datamodule/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from DETR import DETRDataModule
|
||||
from FasterRCNN import FasterRCNNDataModule
|
|
@ -7,9 +7,9 @@ import numpy as np
|
|||
|
||||
prefix = "/data/local-files/?d=spheres/"
|
||||
dataset_path = pathlib.Path("./dataset_antoine_laurent/")
|
||||
annotation_path = dataset_path / "annotations.json"
|
||||
annotation_path = dataset_path / "annotations.json" # from labelstudio
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
_VERSION = "2.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
|
@ -105,6 +105,7 @@ class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
|
|||
for annotation in annotations:
|
||||
assert annotation["value"]["keypointlabels"][0] == label
|
||||
|
||||
# convert labels
|
||||
if label == "White":
|
||||
label = "Matte"
|
||||
elif label == "Black":
|
|
@ -5,7 +5,7 @@ import datasets
|
|||
|
||||
dataset_path_train = pathlib.Path("./dataset_illumination/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
_VERSION = "2.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
|
@ -90,9 +90,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
|
|||
meta["gray"]["bounding_box"]["x"] / original_width * width,
|
||||
meta["gray"]["bounding_box"]["y"] / original_height * height,
|
||||
meta["gray"]["bounding_box"]["w"] / original_width * width,
|
||||
meta["gray"]["bounding_box"]["h"] / original_height * height
|
||||
meta["gray"]["bounding_box"]["h"] / original_height * height,
|
||||
),
|
||||
"Matte"
|
||||
"Matte",
|
||||
)
|
||||
|
||||
chrome = (
|
||||
|
@ -100,9 +100,9 @@ class SphereIllumination(datasets.GeneratorBasedBuilder):
|
|||
meta["chrome"]["bounding_box"]["x"] / original_width * width,
|
||||
meta["chrome"]["bounding_box"]["y"] / original_height * height,
|
||||
meta["chrome"]["bounding_box"]["w"] / original_width * width,
|
||||
meta["chrome"]["bounding_box"]["h"] / original_height * height
|
||||
meta["chrome"]["bounding_box"]["h"] / original_height * height,
|
||||
),
|
||||
"Chrome"
|
||||
"Chrome",
|
||||
)
|
||||
|
||||
# generate data
|
|
@ -4,7 +4,7 @@ import datasets
|
|||
|
||||
dataset_path = pathlib.Path("./dataset_predict/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
_VERSION = "2.0.2"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
|
@ -60,7 +60,6 @@ class SpherePredict(datasets.GeneratorBasedBuilder):
|
|||
jpgs = dataset_path.rglob("*.jpg")
|
||||
|
||||
for index, jpg in enumerate(jpgs):
|
||||
|
||||
print(index, jpg, 2)
|
||||
|
||||
# generate data
|
|
@ -4,7 +4,7 @@ import datasets
|
|||
|
||||
dataset_path = pathlib.Path("./dataset_render/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
_VERSION = "2.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
|
@ -20,7 +20,6 @@ _NAMES = [
|
|||
|
||||
|
||||
class SphereSynth(datasets.GeneratorBasedBuilder):
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
|
@ -78,8 +77,9 @@ class SphereSynth(datasets.GeneratorBasedBuilder):
|
|||
1 - float(line[1]),
|
||||
float(line[2]),
|
||||
1 - float(line[3]),
|
||||
line[4].strip()
|
||||
) for line in lines
|
||||
line[4].strip(),
|
||||
)
|
||||
for line in lines
|
||||
]
|
||||
|
||||
bboxes = [
|
13
src/main.py
13
src/main.py
|
@ -1,11 +1,11 @@
|
|||
from datamodule import DETRDataModule
|
||||
from module import DETR
|
||||
from pytorch_lightning.callbacks import (
|
||||
from datamodule import DETRDataModule, FasterRCNNDataModule # noqa: F401
|
||||
from lightning.pytorch.callbacks import (
|
||||
ModelCheckpoint,
|
||||
RichModelSummary,
|
||||
RichProgressBar,
|
||||
)
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from lightning.pytorch.cli import LightningCLI
|
||||
from module import DETR, FasterRCNN # noqa: F401
|
||||
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
|
@ -15,12 +15,11 @@ class MyLightningCLI(LightningCLI):
|
|||
"""Add arguments to parser."""
|
||||
parser.set_defaults(
|
||||
{
|
||||
"trainer.multiple_trainloader_mode": "max_size_cycle",
|
||||
"trainer.max_steps": 5000,
|
||||
"trainer.max_epochs": 1,
|
||||
"trainer.accelerator": "gpu",
|
||||
"trainer.devices": "[1]",
|
||||
"trainer.strategy": "dp",
|
||||
"trainer.devices": "[0]",
|
||||
"trainer.strategy": "auto",
|
||||
"trainer.log_every_n_steps": 25,
|
||||
"trainer.val_check_interval": 200,
|
||||
"trainer.num_sanity_val_steps": 10,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
from lightning.pytorch import LightningModule
|
||||
from PIL import ImageDraw
|
||||
from pytorch_lightning import LightningModule
|
||||
from transformers import (
|
||||
DetrForObjectDetection,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
|
@ -33,13 +33,14 @@ class DETR(LightningModule):
|
|||
"""
|
||||
super().__init__()
|
||||
|
||||
# replace COCO classification head with custom head
|
||||
# get DETR model
|
||||
self.net = DetrForObjectDetection.from_pretrained(
|
||||
"facebook/detr-resnet-50",
|
||||
ignore_mismatched_sizes=True,
|
||||
num_queries=num_queries,
|
||||
num_labels=num_labels,
|
||||
)
|
||||
torch.compile(self.net)
|
||||
|
||||
# cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
|
||||
self.lr = lr
|
||||
|
@ -47,7 +48,6 @@ class DETR(LightningModule):
|
|||
self.weight_decay = weight_decay
|
||||
self.warmup_steps = warmup_steps
|
||||
self.prediction_threshold = prediction_threshold
|
||||
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, pixel_values, pixel_mask, **kwargs):
|
||||
|
@ -100,7 +100,7 @@ class DETR(LightningModule):
|
|||
"""Training step."""
|
||||
outputs = self.common_step(batch, batch_idx)
|
||||
|
||||
# logs metrics for each training_step and the average across the epoch
|
||||
# logs metrics for each training_step
|
||||
loss = 0
|
||||
for dataloader_name, output in outputs.items():
|
||||
loss += output["loss"]
|
||||
|
@ -117,7 +117,7 @@ class DETR(LightningModule):
|
|||
"""Validation step."""
|
||||
outputs = self.common_step(batch, batch_idx)
|
||||
|
||||
# logs metrics for each validation_step and the average across the epoch
|
||||
# logs metrics for each validation_step
|
||||
loss = 0
|
||||
for dataloader_name, output in outputs.items():
|
||||
loss += output["loss"]
|
114
src/module/FasterRCNN.py
Normal file
114
src/module/FasterRCNN.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
import torch
|
||||
import torchvision
|
||||
from lightning.pytorch import LightningModule
|
||||
from PIL import ImageDraw
|
||||
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights, FastRCNNPredictor
|
||||
|
||||
|
||||
def get_model_instance_segmentation(n_classes: int):
|
||||
"""Returns a Torchvision FasterRCNN model for finetunning.
|
||||
|
||||
Args:
|
||||
n_classes (int): number of classes the model should predict, background excluded
|
||||
"""
|
||||
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
|
||||
weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
|
||||
box_detections_per_img=10, # cap numbers of detections, else oom
|
||||
)
|
||||
|
||||
# get number of input features for the classifier
|
||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||
# replace the pre-trained head with a new one
|
||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes + 1)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class FasterRCNN(LightningModule):
|
||||
"""Faster R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lr: float = 1e-4,
|
||||
weight_decay: float = 1e-4,
|
||||
num_labels: int = 3,
|
||||
):
|
||||
"""Constructor, build model, save hyperparameters."""
|
||||
super().__init__()
|
||||
|
||||
# get Mask R-CNN model
|
||||
self.net = get_model_instance_segmentation(num_labels)
|
||||
|
||||
# hyperparameters
|
||||
self.lr = lr
|
||||
self.weight_decay = weight_decay
|
||||
self.num_labels = num_labels
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, imgs, **kwargs):
|
||||
"""Forward pass."""
|
||||
return self.net(imgs, **kwargs)
|
||||
|
||||
def common_step(self, batchs, batch_idx):
|
||||
# intialize outputs
|
||||
outputs = {}
|
||||
|
||||
# for each dataloader
|
||||
for dataloader_name, batch in batchs.items():
|
||||
# extract pixel_values and labels from batch
|
||||
images = batch["pixel_values"]
|
||||
targets = batch["labels"]
|
||||
|
||||
# forward pass
|
||||
model_output = self(images, targets=targets)
|
||||
|
||||
# get loss
|
||||
outputs[dataloader_name] = {
|
||||
"loss": sum(model_output.values()),
|
||||
"loss_dict": model_output,
|
||||
}
|
||||
|
||||
return outputs
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
outputs = self.common_step(batch, batch_idx)
|
||||
|
||||
# logs metrics for each training_step
|
||||
loss = 0
|
||||
for dataloader_name, output in outputs.items():
|
||||
loss += output["loss"]
|
||||
self.log(f"train_loss_{dataloader_name}", output["loss"])
|
||||
for k, v in output["loss_dict"].items():
|
||||
self.log(f"train_loss_{k}_{dataloader_name}", v.item())
|
||||
|
||||
self.log("lr", self.optimizers().param_groups[0]["lr"])
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.train()
|
||||
with torch.no_grad():
|
||||
outputs = self.common_step(batch, batch_idx)
|
||||
self.eval()
|
||||
|
||||
# logs metrics for each validation_step
|
||||
loss = 0
|
||||
for dataloader_name, output in outputs.items():
|
||||
loss += output["loss"]
|
||||
self.log(f"val_loss_{dataloader_name}", output["loss"])
|
||||
for k, v in output["loss_dict"].items():
|
||||
self.log(f"val_loss_{k}_{dataloader_name}", v.item())
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""PyTorch optimizers and Schedulers.
|
||||
|
||||
Returns:
|
||||
dictionnary for PyTorch Lightning optimizer/scheduler configuration
|
||||
"""
|
||||
optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
}
|
2
src/module/__init__.py
Normal file
2
src/module/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from DETR import DETR
|
||||
from FasterRCNN import FasterRCNN
|
Loading…
Reference in a new issue