mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
feat: more docstrings and typing
Former-commit-id: 9082187ec9d66e93c0195374022290cb9231be00 [formerly c55aabb212241973372630e9a078da5fc0342abf] Former-commit-id: ee4dbe5392b99e03db27da2873c180c84a75737f
This commit is contained in:
parent
be37d2706f
commit
1ed1bb185c
|
@ -1,8 +1,10 @@
|
||||||
|
"""Pytorch Lightning DataModules."""
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import wandb
|
import wandb
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from torch.utils.data import DataLoader, Subset
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .dataset import LabeledDataset, RealDataset
|
from .dataset import LabeledDataset, RealDataset
|
||||||
|
|
||||||
|
@ -12,10 +14,14 @@ def collate_fn(batch):
|
||||||
|
|
||||||
|
|
||||||
class Spheres(pl.LightningDataModule):
|
class Spheres(pl.LightningDataModule):
|
||||||
def __init__(self):
|
"""Pytorch Lightning DataModule, encapsulating common PyTorch functions."""
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self) -> DataLoader:
|
||||||
|
"""PyTorch training Dataloader.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataLoader: the training dataloader
|
||||||
|
"""
|
||||||
transforms = A.Compose(
|
transforms = A.Compose(
|
||||||
[
|
[
|
||||||
# A.Flip(),
|
# A.Flip(),
|
||||||
|
@ -40,7 +46,7 @@ class Spheres(pl.LightningDataModule):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = LabeledDataset("/dev/shm/TRAIN/", transforms)
|
dataset = LabeledDataset(image_dir="/dev/shm/TRAIN/", transforms=transforms)
|
||||||
# dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
|
# dataset = Subset(dataset, range(6 * 200)) # subset for debugging purpose
|
||||||
# dataset = Subset(dataset, [0] * 320) # overfit test
|
# dataset = Subset(dataset, [0] * 320) # overfit test
|
||||||
|
|
||||||
|
@ -55,7 +61,12 @@ class Spheres(pl.LightningDataModule):
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self) -> DataLoader:
|
||||||
|
"""PyTorch validation Dataloader.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataLoader: the validation dataloader
|
||||||
|
"""
|
||||||
transforms = A.Compose(
|
transforms = A.Compose(
|
||||||
[
|
[
|
||||||
A.Normalize(
|
A.Normalize(
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""Pytorch Datasets."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -9,14 +11,14 @@ from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataset(Dataset):
|
class SyntheticDataset(Dataset):
|
||||||
def __init__(self, image_dir, transform):
|
def __init__(self, image_dir: str, transform: A.Compose) -> None:
|
||||||
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.images)
|
return len(self.images)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index: int):
|
||||||
# open and convert image
|
# open and convert image
|
||||||
image = np.ascontiguousarray(
|
image = np.ascontiguousarray(
|
||||||
Image.open(
|
Image.open(
|
||||||
|
@ -40,7 +42,7 @@ class SyntheticDataset(Dataset):
|
||||||
|
|
||||||
|
|
||||||
class RealDataset(Dataset):
|
class RealDataset(Dataset):
|
||||||
def __init__(self, root, transforms=None):
|
def __init__(self, root, transforms=None) -> None:
|
||||||
self.root = root
|
self.root = root
|
||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
|
|
||||||
|
@ -50,7 +52,10 @@ class RealDataset(Dataset):
|
||||||
|
|
||||||
self.res = A.LongestMaxSize(max_size=1024)
|
self.res = A.LongestMaxSize(max_size=1024)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __len__(self) -> int:
|
||||||
|
return len(self.imgs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
# create paths from ids
|
# create paths from ids
|
||||||
image_path = os.path.join(self.root, "images", self.imgs[idx])
|
image_path = os.path.join(self.root, "images", self.imgs[idx])
|
||||||
mask_path = os.path.join(self.root, "masks", self.masks[idx])
|
mask_path = os.path.join(self.root, "masks", self.masks[idx])
|
||||||
|
@ -127,19 +132,16 @@ class RealDataset(Dataset):
|
||||||
|
|
||||||
return image, target
|
return image, target
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.imgs)
|
|
||||||
|
|
||||||
|
|
||||||
class LabeledDataset(Dataset):
|
class LabeledDataset(Dataset):
|
||||||
def __init__(self, image_dir, transforms):
|
def __init__(self, image_dir, transforms) -> None:
|
||||||
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.images)
|
return len(self.images)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx: int):
|
||||||
# open and convert image
|
# open and convert image
|
||||||
image = np.ascontiguousarray(
|
image = np.ascontiguousarray(
|
||||||
Image.open(self.images[idx]).convert("RGB"),
|
Image.open(self.images[idx]).convert("RGB"),
|
||||||
|
|
|
@ -47,7 +47,7 @@ def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
|
||||||
|
|
||||||
|
|
||||||
class MRCNNModule(pl.LightningModule):
|
class MRCNNModule(pl.LightningModule):
|
||||||
"""Mask R-CNN Pytorch Lightning Module encapsulating commong PyTorch functions."""
|
"""Mask R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""
|
||||||
|
|
||||||
def __init__(self, n_classes: int) -> None:
|
def __init__(self, n_classes: int) -> None:
|
||||||
"""Constructor, build model, save hyperparameters.
|
"""Constructor, build model, save hyperparameters.
|
||||||
|
|
12
src/train.py
12
src/train.py
|
@ -1,4 +1,4 @@
|
||||||
import logging
|
"""Main script, to be launched to start the fine tuning of the neural network."""
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import wandb
|
import wandb
|
||||||
|
@ -6,24 +6,16 @@ from pytorch_lightning.callbacks import (
|
||||||
EarlyStopping,
|
EarlyStopping,
|
||||||
LearningRateMonitor,
|
LearningRateMonitor,
|
||||||
ModelCheckpoint,
|
ModelCheckpoint,
|
||||||
ModelPruning,
|
|
||||||
QuantizationAwareTraining,
|
|
||||||
RichModelSummary,
|
RichModelSummary,
|
||||||
RichProgressBar,
|
RichProgressBar,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.loggers import WandbLogger
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
from data import Spheres
|
from data import Spheres
|
||||||
from mrcnn import MRCNNModule
|
from modules import MRCNNModule
|
||||||
from utils.callback import TableLog
|
from utils.callback import TableLog
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# setup logging
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(levelname)s: %(message)s",
|
|
||||||
)
|
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
logger = WandbLogger(
|
logger = WandbLogger(
|
||||||
project="Mask R-CNN",
|
project="Mask R-CNN",
|
||||||
|
|
Loading…
Reference in a new issue