feat: random code I don't want to make commit messages for

Former-commit-id: b94db28e25c4ada7f69d65185198a701bb5d6bfd [formerly 2476ee5d84287e40c8fb341f569249dc8aaff3e5]
Former-commit-id: 0a4b7a1f925165172b009f8812d3083e70f10201
Laurent Fainsin, 2022-07-12 11:18:03 +02:00
commit 50d18a5b39 (parent 9b7418facd)
5 changed files with 65 additions and 30 deletions

Changed file 1/5: the Spheres Lightning datamodule (filename not shown in the extract):

@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, Subset
 import wandb
 from utils import RandomPaste
-from .dataset import LabeledDataset, SyntheticDataset
+from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset


 class Spheres(pl.LightningDataModule):

@@ -13,24 +13,26 @@ class Spheres(pl.LightningDataModule):
         super().__init__()

     def train_dataloader(self):
-        transform = A.Compose(
-            [
-                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-                A.Flip(),
-                A.ColorJitter(),
-                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
-                A.GaussianBlur(),
-                A.ISONoise(),
-            ],
-        )
-        dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
+        # transform = A.Compose(
+        #     [
+        #         A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+        #         A.Flip(),
+        #         A.ColorJitter(),
+        #         RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
+        #         A.GaussianBlur(),
+        #         A.ISONoise(),
+        #     ],
+        # )
+        # dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
+        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
+        dataset = LabeledDataset2(image_dir="/home/lilian/data_disk/lfainsin/prerender/")

         return DataLoader(
             dataset,
             shuffle=True,
-            prefetch_factor=8,
+            prefetch_factor=wandb.config.PREFETCH_FACTOR,
             batch_size=wandb.config.TRAIN_BATCH_SIZE,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,

@@ -38,13 +40,12 @@ class Spheres(pl.LightningDataModule):
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

         return DataLoader(
             dataset,
             shuffle=False,
+            prefetch_factor=wandb.config.PREFETCH_FACTOR,
             batch_size=wandb.config.VAL_BATCH_SIZE,
-            prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
         )
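Note on the prefetch_factor wiring above: PyTorch's DataLoader treats prefetch_factor as the number of batches each worker keeps preloaded, and it only applies when num_workers > 0. A minimal, self-contained sketch of the behavior (toy dataset and values, not from this repo):

import torch
from torch.utils.data import DataLoader, TensorDataset

# toy stand-in for the real dataset: 256 image/mask pairs
dataset = TensorDataset(torch.randn(256, 3, 64, 64), torch.zeros(256, 64, 64))

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=2,       # setting prefetch_factor with num_workers=0 raises an error
    prefetch_factor=16,  # each worker keeps up to 16 batches ready
    pin_memory=True,     # page-locked host memory speeds up transfers to the GPU
)

for images, masks in loader:
    pass  # the training step consumes the prefetched batches here

Reading the value from wandb.config.PREFETCH_FACTOR instead of hard-coding 8 makes it sweepable from the YAML config changed below.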

Changed file 2/5: the dataset module defining LabeledDataset and the new LabeledDataset2 (filename not shown in the extract):

@@ -77,3 +77,36 @@ class LabeledDataset(Dataset):
         mask = mask.float()

         return image, mask
+
+
+class LabeledDataset2(Dataset):
+    def __init__(self, image_dir):
+        self.images = list(Path(image_dir).glob("**/*.jpg"))
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        # open and convert image
+        image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
+
+        # open and convert mask
+        mask_path = self.images[index].parent.joinpath("MASK.PNG")
+        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255
+
+        # convert image & mask to Tensor float in [0, 1]
+        post_process = A.Compose(
+            [
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+        augmentations = post_process(image=image, mask=mask)
+        image = augmentations["image"]
+        mask = augmentations["mask"]
+
+        # make sure image and mask are floats; TODO: move into post_process (ToFloat converts the image only)
+        image = image.float()
+        mask = mask.float()
+
+        return image, mask
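Judging by the recursive glob and the fixed MASK.PNG name, LabeledDataset2 expects each prerendered JPEG to live in a folder alongside its mask; the // 255 collapses a 0/255 grayscale mask to 0/1, and the trailing .float() calls are needed because A.ToFloat rescales the image only, not the mask. A hedged usage sketch (the on-disk layout and import path are assumptions, not shown in the diff):

# assumed prerender layout, inferred from the glob pattern:
#   prerender/
#     0001/image.jpg
#     0001/MASK.PNG
#     0002/image.jpg
#     0002/MASK.PNG

from dataset import LabeledDataset2  # hypothetical module path

ds = LabeledDataset2(image_dir="prerender/")
image, mask = ds[0]
print(image.shape, image.dtype)  # e.g. torch.Size([3, 512, 512]) torch.float32
print(mask.unique())             # tensor([0., 1.]) for a binary mask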

Changed file 3/5: the ONNX prediction script (filename not shown in the extract):

@@ -5,7 +5,6 @@ import albumentations as A
 import numpy as np
 import onnx
 import onnxruntime
-import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image

@@ -58,13 +57,13 @@ if __name__ == "__main__":
     img = Image.open(args.input).convert("RGB")

     logging.info(f"Preprocessing image {args.input}")
-    tf = A.Compose(
+    transform = A.Compose(
         [
             A.ToFloat(max_value=255),
             ToTensorV2(),
         ],
     )
-    aug = tf(image=np.asarray(img))
+    aug = transform(image=np.asarray(img))
     img = aug["image"]

     logging.info(f"Predicting image {args.input}")

Changed file 4/5: the W&B callbacks module defining TableLog and ArtifactLog (filename not shown in the extract):

@@ -67,9 +67,11 @@ class TableLog(Callback):


 class ArtifactLog(Callback):
+    def on_fit_start(self, trainer, pl_module):
+        self.best = 1
+
     def on_validation_epoch_start(self, trainer, pl_module):
         self.dices = []
-        self.best = 1

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         # unpacking

@@ -78,16 +80,14 @@ class ArtifactLog(Callback):
     def on_validation_epoch_end(self, trainer, pl_module):
         dice = np.mean(self.dices)
-        self.dices = []

         if dice < self.best:
             self.best = dice

             # create checkpoint
-            torch.save(self.state_dict(), "checkpoints/model.pth")
-            # trainer.save_checkpoint("example.ckpt")  # TODO: change to .ckpt
+            trainer.save_checkpoint("checkpoints/model.ckpt")

-            # create and log artifact
-            artifact = wandb.Artifact("pth", type="model")
-            artifact.add_file("checkpoints/model.pth")
+            # log artifact
+            artifact = wandb.Artifact("ckpt", type="model")
+            artifact.add_file("checkpoints/model.ckpt")
             wandb.run.log_artifact(artifact)
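The ArtifactLog change is a bug fix as much as a cleanup: self.best is now initialized once in on_fit_start, so it persists across validation epochs, whereas resetting it to 1 in on_validation_epoch_start made any epoch with dice < 1 look like a new best. Replacing torch.save(self.state_dict(), ...), which serialized the callback's state rather than the model, with trainer.save_checkpoint(...) writes a real Lightning checkpoint. A minimal sketch of the resulting pattern (names simplified; dice is treated as loss-like, lower is better; the outputs key is an assumption):

import numpy as np
from pytorch_lightning.callbacks import Callback


class BestCheckpoint(Callback):  # hypothetical name, same shape as ArtifactLog
    def on_fit_start(self, trainer, pl_module):
        self.best = 1  # set once per run, so the best survives across epochs

    def on_validation_epoch_start(self, trainer, pl_module):
        self.metrics = []  # fresh buffer for this epoch's batch metrics

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.metrics.append(outputs["dice"])  # assumed key in the step outputs

    def on_validation_epoch_end(self, trainer, pl_module):
        metric = np.mean(self.metrics)
        if metric < self.best:  # lower is better
            self.best = metric
            trainer.save_checkpoint("checkpoints/model.ckpt")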

Changed file 5/5: the W&B run configuration (YAML; filename not shown in the extract):

@@ -26,14 +26,16 @@ WORKERS:
 IMG_SIZE:
   value: 512
 SPHERES:
-  value: 5
+  value: 3

 EPOCHS:
-  value: 10
+  value: 20
 TRAIN_BATCH_SIZE:
-  value: 16
+  value: 64 # 100
 VAL_BATCH_SIZE:
-  value: 8
+  value: 8 # 10
+PREFETCH_FACTOR:
+  value: 16

 LEARNING_RATE:
   value: 1.0e-4
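These keys surface in the code as attributes of wandb.config once the run is initialized; the value: nesting is the W&B config-file format, which wandb unwraps on load. A minimal sketch (the project name and config filename are assumptions):

import wandb

# assuming the YAML above is saved as config.yaml next to the script
wandb.init(project="sphere-unet", config="config.yaml")

print(wandb.config.TRAIN_BATCH_SIZE)  # 64
print(wandb.config.PREFETCH_FACTOR)   # 16, consumed by both DataLoaders above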