feat: more docstrings and typing

Former-commit-id: 9082187ec9d66e93c0195374022290cb9231be00 [formerly c55aabb212241973372630e9a078da5fc0342abf]
Former-commit-id: ee4dbe5392b99e03db27da2873c180c84a75737f
This commit is contained in:
Laurent Fainsin 2022-09-12 11:45:19 +02:00
parent be37d2706f
commit 1ed1bb185c
4 changed files with 33 additions and 28 deletions

View file

@ -1,8 +1,10 @@
"""Pytorch Lightning DataModules."""
import albumentations as A import albumentations as A
import pytorch_lightning as pl import pytorch_lightning as pl
import wandb import wandb
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Subset from torch.utils.data import DataLoader
from .dataset import LabeledDataset, RealDataset from .dataset import LabeledDataset, RealDataset
@ -12,10 +14,14 @@ def collate_fn(batch):
class Spheres(pl.LightningDataModule): class Spheres(pl.LightningDataModule):
def __init__(self): """Pytorch Lightning DataModule, encapsulating common PyTorch functions."""
super().__init__()
def train_dataloader(self): def train_dataloader(self) -> DataLoader:
"""PyTorch training Dataloader.
Returns:
DataLoader: the training dataloader
"""
transforms = A.Compose( transforms = A.Compose(
[ [
# A.Flip(), # A.Flip(),
@ -40,7 +46,7 @@ class Spheres(pl.LightningDataModule):
), ),
) )
dataset = LabeledDataset("/dev/shm/TRAIN/", transforms) dataset = LabeledDataset(image_dir="/dev/shm/TRAIN/", transforms=transforms)
# dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose # dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
# dataset = Subset(dataset, [0] * 320) # overfit test # dataset = Subset(dataset, [0] * 320) # overfit test
@ -55,7 +61,12 @@ class Spheres(pl.LightningDataModule):
collate_fn=collate_fn, collate_fn=collate_fn,
) )
def val_dataloader(self): def val_dataloader(self) -> DataLoader:
"""PyTorch validation Dataloader.
Returns:
DataLoader: the validation dataloader
"""
transforms = A.Compose( transforms = A.Compose(
[ [
A.Normalize( A.Normalize(

View file

@ -1,3 +1,5 @@
"""Pytorch Datasets."""
import os import os
from pathlib import Path from pathlib import Path
@ -9,14 +11,14 @@ from torch.utils.data import Dataset
class SyntheticDataset(Dataset): class SyntheticDataset(Dataset):
def __init__(self, image_dir, transform): def __init__(self, image_dir: str, transform: A.Compose) -> None:
self.images = list(Path(image_dir).glob("**/*.jpg")) self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transform = transform self.transform = transform
def __len__(self): def __len__(self) -> int:
return len(self.images) return len(self.images)
def __getitem__(self, index): def __getitem__(self, index: int):
# open and convert image # open and convert image
image = np.ascontiguousarray( image = np.ascontiguousarray(
Image.open( Image.open(
@ -40,7 +42,7 @@ class SyntheticDataset(Dataset):
class RealDataset(Dataset): class RealDataset(Dataset):
def __init__(self, root, transforms=None): def __init__(self, root, transforms=None) -> None:
self.root = root self.root = root
self.transforms = transforms self.transforms = transforms
@ -50,7 +52,10 @@ class RealDataset(Dataset):
self.res = A.LongestMaxSize(max_size=1024) self.res = A.LongestMaxSize(max_size=1024)
def __getitem__(self, idx): def __len__(self) -> int:
return len(self.imgs)
def __getitem__(self, idx: int):
# create paths from ids # create paths from ids
image_path = os.path.join(self.root, "images", self.imgs[idx]) image_path = os.path.join(self.root, "images", self.imgs[idx])
mask_path = os.path.join(self.root, "masks", self.masks[idx]) mask_path = os.path.join(self.root, "masks", self.masks[idx])
@ -127,19 +132,16 @@ class RealDataset(Dataset):
return image, target return image, target
def __len__(self):
return len(self.imgs)
class LabeledDataset(Dataset): class LabeledDataset(Dataset):
def __init__(self, image_dir, transforms): def __init__(self, image_dir, transforms) -> None:
self.images = list(Path(image_dir).glob("**/*.jpg")) self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transforms = transforms self.transforms = transforms
def __len__(self): def __len__(self) -> int:
return len(self.images) return len(self.images)
def __getitem__(self, idx): def __getitem__(self, idx: int):
# open and convert image # open and convert image
image = np.ascontiguousarray( image = np.ascontiguousarray(
Image.open(self.images[idx]).convert("RGB"), Image.open(self.images[idx]).convert("RGB"),

View file

@ -47,7 +47,7 @@ def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
class MRCNNModule(pl.LightningModule): class MRCNNModule(pl.LightningModule):
"""Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions.""" """Mask R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
def __init__(self, n_classes: int) -> None: def __init__(self, n_classes: int) -> None:
"""Constructor, build model, save hyperparameters. """Constructor, build model, save hyperparameters.

View file

@ -1,4 +1,4 @@
import logging """Main script, to be launched to start the fine tuning of the neural network."""
import pytorch_lightning as pl import pytorch_lightning as pl
import wandb import wandb
@ -6,24 +6,16 @@ from pytorch_lightning.callbacks import (
EarlyStopping, EarlyStopping,
LearningRateMonitor, LearningRateMonitor,
ModelCheckpoint, ModelCheckpoint,
ModelPruning,
QuantizationAwareTraining,
RichModelSummary, RichModelSummary,
RichProgressBar, RichProgressBar,
) )
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from data import Spheres from data import Spheres
from mrcnn import MRCNNModule from modules import MRCNNModule
from utils.callback import TableLog from utils.callback import TableLog
if __name__ == "__main__": if __name__ == "__main__":
# setup logging
logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(message)s",
)
# setup wandb # setup wandb
logger = WandbLogger( logger = WandbLogger(
project="Mask R-CNN", project="Mask R-CNN",