From 50d18a5b39df85d8f6c28d654163aa2d4794cbad Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Tue, 12 Jul 2022 11:18:03 +0200 Subject: [PATCH] feat: random code I don't want to make commit messages for Former-commit-id: b94db28e25c4ada7f69d65185198a701bb5d6bfd [formerly 2476ee5d84287e40c8fb341f569249dc8aaff3e5] Former-commit-id: 0a4b7a1f925165172b009f8812d3083e70f10201 --- src/data/dataloader.py | 33 +++++++++++++++++---------------- src/data/dataset.py | 33 +++++++++++++++++++++++++++++++++ src/predict.py | 5 ++--- src/utils/callback.py | 14 +++++++------- wandb.yaml | 10 ++++++---- 5 files changed, 65 insertions(+), 30 deletions(-) diff --git a/src/data/dataloader.py b/src/data/dataloader.py index 9e67c4c..f21dbd3 100644 --- a/src/data/dataloader.py +++ b/src/data/dataloader.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, Subset import wandb from utils import RandomPaste -from .dataset import LabeledDataset, SyntheticDataset +from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset class Spheres(pl.LightningDataModule): @@ -13,24 +13,26 @@ class Spheres(pl.LightningDataModule): super().__init__() def train_dataloader(self): - transform = A.Compose( - [ - A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), - A.Flip(), - A.ColorJitter(), - RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), - A.GaussianBlur(), - A.ISONoise(), - ], - ) + # transform = A.Compose( + # [ + # A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + # A.Flip(), + # A.ColorJitter(), + # RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE), + # A.GaussianBlur(), + # A.ISONoise(), + # ], + # ) - dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform) - dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1))) + # dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform) + # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1))) + + dataset = LabeledDataset2(image_dir="/home/lilian/data_disk/lfainsin/prerender/") return DataLoader( dataset, shuffle=True, - prefetch_factor=8, + prefetch_factor=wandb.config.PREFETCH_FACTOR, batch_size=wandb.config.TRAIN_BATCH_SIZE, num_workers=wandb.config.WORKERS, pin_memory=wandb.config.PIN_MEMORY, @@ -38,13 +40,12 @@ class Spheres(pl.LightningDataModule): def val_dataloader(self): dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG) - # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1))) return DataLoader( dataset, shuffle=False, + prefetch_factor=wandb.config.PREFETCH_FACTOR, batch_size=wandb.config.VAL_BATCH_SIZE, - prefetch_factor=8, num_workers=wandb.config.WORKERS, pin_memory=wandb.config.PIN_MEMORY, ) diff --git a/src/data/dataset.py b/src/data/dataset.py index 61508f2..439e284 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -77,3 +77,36 @@ class LabeledDataset(Dataset): mask = mask.float() return image, mask + + +class LabeledDataset2(Dataset): + def __init__(self, image_dir): + self.images = list(Path(image_dir).glob("**/*.jpg")) + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + # open and convert image + image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8) + + # open and convert mask + mask_path = self.images[index].parent.joinpath("MASK.PNG") + mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255 + + # convert image & mask to Tensor float in [0, 1] + post_process = A.Compose( + [ 
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+        augmentations = post_process(image=image, mask=mask)
+        image = augmentations["image"]
+        mask = augmentations["mask"]
+
+        # make sure image and mask are floats, TODO: move this into post_process, ToFloat applies to the image only
+        image = image.float()
+        mask = mask.float()
+
+        return image, mask
diff --git a/src/predict.py b/src/predict.py
index f1be604..7dd0ffd 100755
--- a/src/predict.py
+++ b/src/predict.py
@@ -5,7 +5,6 @@ import albumentations as A
 import numpy as np
 import onnx
 import onnxruntime
-import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image
 
@@ -58,13 +57,13 @@ if __name__ == "__main__":
     img = Image.open(args.input).convert("RGB")
 
     logging.info(f"Preprocessing image {args.input}")
-    tf = A.Compose(
+    transform = A.Compose(
         [
             A.ToFloat(max_value=255),
             ToTensorV2(),
         ],
     )
-    aug = tf(image=np.asarray(img))
+    aug = transform(image=np.asarray(img))
     img = aug["image"]
 
     logging.info(f"Predicting image {args.input}")
diff --git a/src/utils/callback.py b/src/utils/callback.py
index 969b619..15cd5e5 100644
--- a/src/utils/callback.py
+++ b/src/utils/callback.py
@@ -67,9 +67,11 @@ class TableLog(Callback):
 
 
 class ArtifactLog(Callback):
+    def on_fit_start(self, trainer, pl_module):
+        self.best = 1
+
     def on_validation_epoch_start(self, trainer, pl_module):
         self.dices = []
-        self.best = 1
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         # unpacking
@@ -78,16 +80,14 @@ class ArtifactLog(Callback):
 
     def on_validation_epoch_end(self, trainer, pl_module):
         dice = np.mean(self.dices)
-        self.dices = []
 
         if dice < self.best:
             self.best = dice
 
             # create checkpoint
-            torch.save(self.state_dict(), "checkpoints/model.pth")
-            # trainer.save_checkpoint("example.ckpt") # TODO: change to .ckpt
+            trainer.save_checkpoint("checkpoints/model.ckpt")
 
-            # create and log artifact
-            artifact = wandb.Artifact("pth", type="model")
-            artifact.add_file("checkpoints/model.pth")
+            # log artifact
+            artifact = wandb.Artifact("ckpt", type="model")
+            artifact.add_file("checkpoints/model.ckpt")
             wandb.run.log_artifact(artifact)
diff --git a/wandb.yaml b/wandb.yaml
index 2515950..8bbba18 100644
--- a/wandb.yaml
+++ b/wandb.yaml
@@ -26,14 +26,16 @@ WORKERS:
 IMG_SIZE:
   value: 512
 SPHERES:
-  value: 5
+  value: 3
 EPOCHS:
-  value: 10
+  value: 20
 TRAIN_BATCH_SIZE:
-  value: 16
+  value: 64 # 100
 VAL_BATCH_SIZE:
-  value: 8
+  value: 8 # 10
+PREFETCH_FACTOR:
+  value: 16
 LEARNING_RATE:
   value: 1.0e-4
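
Note: the sketch below is a minimal, hypothetical smoke test for the LabeledDataset2 / PREFETCH_FACTOR changes above, not part of the patch. The data path, batch size, worker count and import path are placeholder assumptions; it only checks that each pre-rendered sample loads as a float image/mask pair and batches cleanly with prefetching enabled.

import torch
from torch.utils.data import DataLoader

# Assumes src/ is on the Python path; adjust the import to your layout.
from data.dataset import LabeledDataset2

# Placeholder path: a folder of pre-rendered samples, each sub-folder holding
# an image *.jpg next to its MASK.PNG, all crops sharing the same size.
dataset = LabeledDataset2(image_dir="/path/to/prerender/")

image, mask = dataset[0]
assert image.dtype == torch.float32 and mask.dtype == torch.float32

# prefetch_factor only takes effect when num_workers > 0, which is why the
# patch reads it from wandb.config alongside WORKERS.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, prefetch_factor=2)
images, masks = next(iter(loader))
print(images.shape, masks.shape)  # e.g. torch.Size([4, 3, H, W]) and torch.Size([4, H, W])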