From 1ed1bb185c0314a259c76cdca89ca00a822b737a Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Mon, 12 Sep 2022 11:45:19 +0200
Subject: [PATCH] feat: more docstrings and typing

Former-commit-id: 9082187ec9d66e93c0195374022290cb9231be00 [formerly c55aabb212241973372630e9a078da5fc0342abf]
Former-commit-id: ee4dbe5392b99e03db27da2873c180c84a75737f
---
 src/data/dataloader.py | 23 +++++++++++++++++------
 src/data/dataset.py    | 24 +++++++++++++-----------
 src/modules/mrcnn.py   |  2 +-
 src/train.py           | 12 ++----------
 4 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index 120ac10..c1f444a 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -1,8 +1,10 @@
+"""Pytorch Lightning DataModules."""
+
 import albumentations as A
 import pytorch_lightning as pl
 import wandb
 from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader, Subset
+from torch.utils.data import DataLoader
 
 from .dataset import LabeledDataset, RealDataset
 
@@ -12,10 +14,14 @@ def collate_fn(batch):
 
 
 class Spheres(pl.LightningDataModule):
-    def __init__(self):
-        super().__init__()
+    """Pytorch Lightning DataModule, encapsulating common PyTorch functions."""
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
+        """PyTorch training Dataloader.
+
+        Returns:
+            DataLoader: the training dataloader
+        """
         transforms = A.Compose(
             [
                 # A.Flip(),
@@ -40,7 +46,7 @@ class Spheres(pl.LightningDataModule):
             ),
         )
 
-        dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
+        dataset = LabeledDataset(image_dir="/dev/shm/TRAIN/", transforms=transforms)
         # dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
         # dataset = Subset(dataset, [0] * 320) # overfit test
 
@@ -55,7 +61,12 @@ class Spheres(pl.LightningDataModule):
             collate_fn=collate_fn,
        )
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
+        """PyTorch validation Dataloader.
+
+        Returns:
+            DataLoader: the validation dataloader
+        """
         transforms = A.Compose(
             [
                 A.Normalize(
diff --git a/src/data/dataset.py b/src/data/dataset.py
index 3addc8b..f0e6a82 100644
--- a/src/data/dataset.py
+++ b/src/data/dataset.py
@@ -1,3 +1,5 @@
+"""Pytorch Datasets."""
+
 import os
 from pathlib import Path
 
@@ -9,14 +11,14 @@ from torch.utils.data import Dataset
 
 
 class SyntheticDataset(Dataset):
-    def __init__(self, image_dir, transform):
+    def __init__(self, image_dir: str, transform: A.Compose) -> None:
         self.images = list(Path(image_dir).glob("**/*.jpg"))
         self.transform = transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.images)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int):
         # open and convert image
         image = np.ascontiguousarray(
             Image.open(
@@ -40,7 +42,7 @@
 
 
 class RealDataset(Dataset):
-    def __init__(self, root, transforms=None):
+    def __init__(self, root, transforms=None) -> None:
         self.root = root
         self.transforms = transforms
 
@@ -50,7 +52,10 @@
 
         self.res = A.LongestMaxSize(max_size=1024)
 
-    def __getitem__(self, idx):
+    def __len__(self) -> int:
+        return len(self.imgs)
+
+    def __getitem__(self, idx: int):
         # create paths from ids
         image_path = os.path.join(self.root, "images", self.imgs[idx])
         mask_path = os.path.join(self.root, "masks", self.masks[idx])
@@ -127,19 +132,16 @@
 
         return image, target
 
-    def __len__(self):
-        return len(self.imgs)
-
 
 class LabeledDataset(Dataset):
-    def __init__(self, image_dir, transforms):
+    def __init__(self, image_dir, transforms) -> None:
         self.images = list(Path(image_dir).glob("**/*.jpg"))
         self.transforms = transforms
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.images)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         # open and convert image
         image = np.ascontiguousarray(
             Image.open(self.images[idx]).convert("RGB"),
diff --git a/src/modules/mrcnn.py b/src/modules/mrcnn.py
index da19a1e..c37c96d 100644
--- a/src/modules/mrcnn.py
+++ b/src/modules/mrcnn.py
@@ -47,7 +47,7 @@ def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
 
 
 class MRCNNModule(pl.LightningModule):
-    """Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions."""
+    """Mask R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
 
     def __init__(self, n_classes: int) -> None:
         """Constructor, build model, save hyperparameters.
diff --git a/src/train.py b/src/train.py
index eabaf31..98f0b13 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,4 +1,4 @@
-import logging
+"""Main script, to be launched to start the fine tuning of the neural network."""
 
 import pytorch_lightning as pl
 import wandb
@@ -6,24 +6,16 @@ from pytorch_lightning.callbacks import (
     EarlyStopping,
     LearningRateMonitor,
     ModelCheckpoint,
-    ModelPruning,
-    QuantizationAwareTraining,
     RichModelSummary,
     RichProgressBar,
 )
 from pytorch_lightning.loggers import WandbLogger
 
 from data import Spheres
-from mrcnn import MRCNNModule
+from modules import MRCNNModule
 from utils.callback import TableLog
 
 if __name__ == "__main__":
-    # setup logging
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(levelname)s: %(message)s",
-    )
-
     # setup wandb
     logger = WandbLogger(
         project="Mask R-CNN",