feat: more docstrings and typing

Former-commit-id: 9082187ec9d66e93c0195374022290cb9231be00 [formerly c55aabb212241973372630e9a078da5fc0342abf]
Former-commit-id: ee4dbe5392b99e03db27da2873c180c84a75737f
This commit is contained in:
Laurent Fainsin 2022-09-12 11:45:19 +02:00
parent be37d2706f
commit 1ed1bb185c
4 changed files with 33 additions and 28 deletions

View file

@ -1,8 +1,10 @@
"""Pytorch Lightning DataModules."""
import albumentations as A
import pytorch_lightning as pl
import wandb
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Subset
from torch.utils.data import DataLoader
from .dataset import LabeledDataset, RealDataset
@ -12,10 +14,14 @@ def collate_fn(batch):
class Spheres(pl.LightningDataModule):
def __init__(self):
super().__init__()
"""Pytorch Lightning DataModule, encapsulating common PyTorch functions."""
def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
"""PyTorch training Dataloader.
Returns:
DataLoader: the training dataloader
"""
transforms = A.Compose(
[
# A.Flip(),
@ -40,7 +46,7 @@ class Spheres(pl.LightningDataModule):
),
)
dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
dataset = LabeledDataset(image_dir="/dev/shm/TRAIN/", transforms=transforms)
# dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
# dataset = Subset(dataset, [0] * 320) # overfit test
@ -55,7 +61,12 @@ class Spheres(pl.LightningDataModule):
collate_fn=collate_fn,
)
def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
"""PyTorch validation Dataloader.
Returns:
DataLoader: the validation dataloader
"""
transforms = A.Compose(
[
A.Normalize(

View file

@ -1,3 +1,5 @@
"""Pytorch Datasets."""
import os
from pathlib import Path
@ -9,14 +11,14 @@ from torch.utils.data import Dataset
class SyntheticDataset(Dataset):
def __init__(self, image_dir, transform):
def __init__(self, image_dir: str, transform: A.Compose) -> None:
self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transform = transform
def __len__(self):
def __len__(self) -> int:
return len(self.images)
def __getitem__(self, index):
def __getitem__(self, index: int):
# open and convert image
image = np.ascontiguousarray(
Image.open(
@ -40,7 +42,7 @@ class SyntheticDataset(Dataset):
class RealDataset(Dataset):
def __init__(self, root, transforms=None):
def __init__(self, root, transforms=None) -> None:
self.root = root
self.transforms = transforms
@ -50,7 +52,10 @@ class RealDataset(Dataset):
self.res = A.LongestMaxSize(max_size=1024)
def __getitem__(self, idx):
def __len__(self) -> int:
return len(self.imgs)
def __getitem__(self, idx: int):
# create paths from ids
image_path = os.path.join(self.root, "images", self.imgs[idx])
mask_path = os.path.join(self.root, "masks", self.masks[idx])
@ -127,19 +132,16 @@ class RealDataset(Dataset):
return image, target
def __len__(self):
return len(self.imgs)
class LabeledDataset(Dataset):
def __init__(self, image_dir, transforms):
def __init__(self, image_dir, transforms) -> None:
self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transforms = transforms
def __len__(self):
def __len__(self) -> int:
return len(self.images)
def __getitem__(self, idx):
def __getitem__(self, idx: int):
# open and convert image
image = np.ascontiguousarray(
Image.open(self.images[idx]).convert("RGB"),

View file

@ -47,7 +47,7 @@ def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
class MRCNNModule(pl.LightningModule):
"""Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions."""
"""Mask R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
def __init__(self, n_classes: int) -> None:
"""Constructor, build model, save hyperparameters.

View file

@ -1,4 +1,4 @@
import logging
"""Main script, to be launched to start the fine tuning of the neural network."""
import pytorch_lightning as pl
import wandb
@ -6,24 +6,16 @@ from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
ModelPruning,
QuantizationAwareTraining,
RichModelSummary,
RichProgressBar,
)
from pytorch_lightning.loggers import WandbLogger
from data import Spheres
from mrcnn import MRCNNModule
from modules import MRCNNModule
from utils.callback import TableLog
if __name__ == "__main__":
# setup logging
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(message)s",
)
# setup wandb
logger = WandbLogger(
project="Mask R-CNN",