mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
feat: random code I don't want to make commit messages for
Former-commit-id: b94db28e25c4ada7f69d65185198a701bb5d6bfd [formerly 2476ee5d84287e40c8fb341f569249dc8aaff3e5] Former-commit-id: 0a4b7a1f925165172b009f8812d3083e70f10201
This commit is contained in:
parent
9b7418facd
commit
50d18a5b39
|
@ -5,7 +5,7 @@ from torch.utils.data import DataLoader, Subset
|
||||||
import wandb
|
import wandb
|
||||||
from utils import RandomPaste
|
from utils import RandomPaste
|
||||||
|
|
||||||
from .dataset import LabeledDataset, SyntheticDataset
|
from .dataset import LabeledDataset, LabeledDataset2, SyntheticDataset
|
||||||
|
|
||||||
|
|
||||||
class Spheres(pl.LightningDataModule):
|
class Spheres(pl.LightningDataModule):
|
||||||
|
@ -13,24 +13,26 @@ class Spheres(pl.LightningDataModule):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
transform = A.Compose(
|
# transform = A.Compose(
|
||||||
[
|
# [
|
||||||
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
# A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||||
A.Flip(),
|
# A.Flip(),
|
||||||
A.ColorJitter(),
|
# A.ColorJitter(),
|
||||||
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
# RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
|
||||||
A.GaussianBlur(),
|
# A.GaussianBlur(),
|
||||||
A.ISONoise(),
|
# A.ISONoise(),
|
||||||
],
|
# ],
|
||||||
)
|
# )
|
||||||
|
|
||||||
dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
# dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
||||||
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
|
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
|
||||||
|
|
||||||
|
dataset = LabeledDataset2(image_dir="/home/lilian/data_disk/lfainsin/prerender/")
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
prefetch_factor=8,
|
prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||||
batch_size=wandb.config.TRAIN_BATCH_SIZE,
|
batch_size=wandb.config.TRAIN_BATCH_SIZE,
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
@ -38,13 +40,12 @@ class Spheres(pl.LightningDataModule):
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||||
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
prefetch_factor=wandb.config.PREFETCH_FACTOR,
|
||||||
batch_size=wandb.config.VAL_BATCH_SIZE,
|
batch_size=wandb.config.VAL_BATCH_SIZE,
|
||||||
prefetch_factor=8,
|
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
)
|
||||||
|
|
|
@ -77,3 +77,36 @@ class LabeledDataset(Dataset):
|
||||||
mask = mask.float()
|
mask = mask.float()
|
||||||
|
|
||||||
return image, mask
|
return image, mask
|
||||||
|
|
||||||
|
|
||||||
|
class LabeledDataset2(Dataset):
|
||||||
|
def __init__(self, image_dir):
|
||||||
|
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.images)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
# open and convert image
|
||||||
|
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
|
||||||
|
|
||||||
|
# open and convert mask
|
||||||
|
mask_path = self.images[index].parent.joinpath("MASK.PNG")
|
||||||
|
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255
|
||||||
|
|
||||||
|
# convert image & mask to Tensor float in [0, 1]
|
||||||
|
post_process = A.Compose(
|
||||||
|
[
|
||||||
|
A.ToFloat(max_value=255),
|
||||||
|
ToTensorV2(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
augmentations = post_process(image=image, mask=mask)
|
||||||
|
image = augmentations["image"]
|
||||||
|
mask = augmentations["mask"]
|
||||||
|
|
||||||
|
# make sure image and mask are floats, TODO: mettre dans le post_process, ToFloat Image only
|
||||||
|
image = image.float()
|
||||||
|
mask = mask.float()
|
||||||
|
|
||||||
|
return image, mask
|
||||||
|
|
|
@ -5,7 +5,6 @@ import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnx
|
import onnx
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
import torch
|
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -58,13 +57,13 @@ if __name__ == "__main__":
|
||||||
img = Image.open(args.input).convert("RGB")
|
img = Image.open(args.input).convert("RGB")
|
||||||
|
|
||||||
logging.info(f"Preprocessing image {args.input}")
|
logging.info(f"Preprocessing image {args.input}")
|
||||||
tf = A.Compose(
|
transform = A.Compose(
|
||||||
[
|
[
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
ToTensorV2(),
|
ToTensorV2(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
aug = tf(image=np.asarray(img))
|
aug = transform(image=np.asarray(img))
|
||||||
img = aug["image"]
|
img = aug["image"]
|
||||||
|
|
||||||
logging.info(f"Predicting image {args.input}")
|
logging.info(f"Predicting image {args.input}")
|
||||||
|
|
|
@ -67,9 +67,11 @@ class TableLog(Callback):
|
||||||
|
|
||||||
|
|
||||||
class ArtifactLog(Callback):
|
class ArtifactLog(Callback):
|
||||||
|
def on_fit_start(self, trainer, pl_module):
|
||||||
|
self.best = 1
|
||||||
|
|
||||||
def on_validation_epoch_start(self, trainer, pl_module):
|
def on_validation_epoch_start(self, trainer, pl_module):
|
||||||
self.dices = []
|
self.dices = []
|
||||||
self.best = 1
|
|
||||||
|
|
||||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||||
# unpacking
|
# unpacking
|
||||||
|
@ -78,16 +80,14 @@ class ArtifactLog(Callback):
|
||||||
|
|
||||||
def on_validation_epoch_end(self, trainer, pl_module):
|
def on_validation_epoch_end(self, trainer, pl_module):
|
||||||
dice = np.mean(self.dices)
|
dice = np.mean(self.dices)
|
||||||
self.dices = []
|
|
||||||
|
|
||||||
if dice < self.best:
|
if dice < self.best:
|
||||||
self.best = dice
|
self.best = dice
|
||||||
|
|
||||||
# create checkpoint
|
# create checkpoint
|
||||||
torch.save(self.state_dict(), "checkpoints/model.pth")
|
trainer.save_checkpoint("checkpoints/model.ckpt")
|
||||||
# trainer.save_checkpoint("example.ckpt") # TODO: change to .ckpt
|
|
||||||
|
|
||||||
# create and log artifact
|
# log artifact
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("ckpt", type="model")
|
||||||
artifact.add_file("checkpoints/model.pth")
|
artifact.add_file("checkpoints/model.ckpt")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
10
wandb.yaml
10
wandb.yaml
|
@ -26,14 +26,16 @@ WORKERS:
|
||||||
IMG_SIZE:
|
IMG_SIZE:
|
||||||
value: 512
|
value: 512
|
||||||
SPHERES:
|
SPHERES:
|
||||||
value: 5
|
value: 3
|
||||||
|
|
||||||
EPOCHS:
|
EPOCHS:
|
||||||
value: 10
|
value: 20
|
||||||
TRAIN_BATCH_SIZE:
|
TRAIN_BATCH_SIZE:
|
||||||
value: 16
|
value: 64 # 100
|
||||||
VAL_BATCH_SIZE:
|
VAL_BATCH_SIZE:
|
||||||
value: 8
|
value: 8 # 10
|
||||||
|
PREFETCH_FACTOR:
|
||||||
|
value: 16
|
||||||
|
|
||||||
LEARNING_RATE:
|
LEARNING_RATE:
|
||||||
value: 1.0e-4
|
value: 1.0e-4
|
||||||
|
|
Loading…
Reference in a new issue