refactor: making the code actually work

Former-commit-id: 302f77ef109ca44eef8c4ce6c5b3f59ba2891884 [formerly 5ae3d84ea691440283f73b5f12e2c85b7d95b191]
Former-commit-id: 33cbfef93d9be0cefd526e1cdbd7c0435db4c613
Laurent Fainsin 2022-07-10 17:12:00 +02:00
parent 81cbfd6212
commit 8d903e7ad1
6 changed files with 28 additions and 23 deletions

View file

@@ -1 +1 @@
-from .dataloader import SyntheticSphere
+from .dataloader import Spheres

View file

@@ -1,6 +1,6 @@
 import albumentations as A
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 import wandb
 
 from utils import RandomPaste
@@ -8,7 +8,7 @@ from utils import RandomPaste
 from .dataset import LabeledDataset, SyntheticDataset
 
 
-class SyntheticSphere(pl.LightningDataModule):
+class Spheres(pl.LightningDataModule):
     def __init__(self):
         super().__init__()
 
@@ -25,7 +25,7 @@ class SyntheticSphere(pl.LightningDataModule):
         )
 
         dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
 
         return DataLoader(
             dataset,
@@ -37,6 +37,7 @@ class SyntheticSphere(pl.LightningDataModule):
 
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
+        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
 
         return DataLoader(
             dataset,
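A note on the Subset pattern used in both dataloaders: `len(dataset) // 10000 + 1` keeps the range step at least 1 even when the dataset holds fewer than 10,000 images (a step of 0 would make range raise), and the evenly spaced indices cap an epoch at roughly 10,000 training samples (100 for validation). A minimal sketch with an arbitrary toy dataset size:

import torch
from torch.utils.data import Subset, TensorDataset

# toy stand-in for SyntheticDataset; 25_000 is an arbitrary example size
full = TensorDataset(torch.randn(25_000, 3, 64, 64))

# keep at most ~10_000 evenly spaced samples; "+ 1" guards against a zero step
step = len(full) // 10_000 + 1
subset = Subset(full, list(range(0, len(full), step)))

print(len(full), step, len(subset))  # 25000 3 8334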

View file

@@ -63,7 +63,7 @@ class LabeledDataset(Dataset):
         # convert image & mask to Tensor float in [0, 1]
         post_process = A.Compose(
             [
-                # A.SmallestMaxSize(1024),
+                A.SmallestMaxSize(1024),
                 A.ToFloat(max_value=255),
                 ToTensorV2(),
             ],
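For reference, A.SmallestMaxSize(1024) rescales the image so its shorter side becomes 1024 px while preserving the aspect ratio, before A.ToFloat and ToTensorV2 convert it to a float tensor in [0, 1]. A small self-contained sketch of this post-processing on a dummy image (the input size is made up):

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

post_process = A.Compose(
    [
        A.SmallestMaxSize(1024),   # shorter side -> 1024 px, aspect ratio kept
        A.ToFloat(max_value=255),  # uint8 [0, 255] -> float32 [0, 1]
        ToTensorV2(),              # HWC numpy array -> CHW torch tensor
    ],
)

# dummy uint8 image standing in for a validation sample
image = np.random.randint(0, 256, (900, 1600, 3), dtype=np.uint8)
out = post_process(image=image)["image"]
print(out.shape, out.dtype)  # shorter side is now 1024, dtype is float32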

View file

@@ -1,11 +1,12 @@
 import logging
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import RichProgressBar
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 import wandb
-from unet import UNet
+from data import Spheres
+from unet import UNetModule
 
 CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@@ -18,9 +19,9 @@ CONFIG = {
     "PIN_MEMORY": True,
     "BENCHMARK": True,
     "DEVICE": "gpu",
-    "WORKERS": 10,
+    "WORKERS": 14,
     "EPOCHS": 1,
-    "BATCH_SIZE": 32,
+    "BATCH_SIZE": 16 * 3,
     "LEARNING_RATE": 1e-4,
     "WEIGHT_DECAY": 1e-8,
     "MOMENTUM": 0.9,
@@ -45,7 +46,7 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)
 
     # Create network
-    net = UNet(
+    model = UNetModule(
         n_channels=CONFIG["N_CHANNELS"],
         n_classes=CONFIG["N_CLASSES"],
         batch_size=CONFIG["BATCH_SIZE"],
@@ -54,27 +55,31 @@ if __name__ == "__main__":
     )
 
     # log gradients and weights regularly
-    logger.watch(net, log="all")
+    logger.watch(model, log="all")
 
     # create checkpoint callback
-    checkpoint_callback = pl.ModelCheckpoint(
+    checkpoint_callback = ModelCheckpoint(
         dirpath="checkpoints",
         monitor="val/dice",
     )
 
+    # Create the dataloaders
+    datamodule = Spheres()
+
     # Create the trainer
     trainer = pl.Trainer(
         max_epochs=CONFIG["EPOCHS"],
         accelerator=CONFIG["DEVICE"],
-        # precision=16,
         benchmark=CONFIG["BENCHMARK"],
-        val_check_interval=100,
-        callbacks=RichProgressBar(),
+        # profiler="simple",
+        # precision=16,
         logger=logger,
         log_every_n_steps=1,
+        val_check_interval=100,
+        callbacks=RichProgressBar(),
     )
 
-    trainer.fit(model=net)
+    trainer.fit(model=model, datamodule=datamodule)
 
     # stop wandb
     wandb.run.finish()
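For context, trainer.fit(model=model, datamodule=datamodule) lets the LightningDataModule own the dataloaders instead of the model. A minimal runnable sketch of that wiring, with stub classes standing in for Spheres and UNetModule (the stubs and tensor shapes are purely illustrative, not the project's code):

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class StubDataModule(pl.LightningDataModule):
    """Illustrative stand-in for the Spheres datamodule."""

    def train_dataloader(self):
        ds = TensorDataset(torch.randn(64, 1), torch.randn(64, 1))
        return DataLoader(ds, batch_size=16)


class StubModel(pl.LightningModule):
    """Illustrative stand-in for UNetModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, log_every_n_steps=1)
    # the datamodule supplies the dataloaders, so they no longer live on the model
    trainer.fit(model=StubModel(), datamodule=StubDataModule())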

View file

@@ -29,8 +29,6 @@ class UNet(nn.Module):
     def forward(self, x):
         skips = []
 
-        x = x.to(self.device)
-
         x = self.inc(x)
         for down in self.downs:
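The removed x = x.to(self.device) is the kind of device handling that, in a Lightning setup, belongs to the framework: the Trainer moves each batch to the model's device before forward runs, and a plain nn.Module has no device attribute of its own. A tiny sketch of that split (the module below is illustrative, not the project's UNet):

import torch
from torch import nn


class TinyNet(nn.Module):
    """Plain nn.Module: pure computation, no device transfers in forward."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # no x.to(...): the caller (here, Lightning's Trainer) is responsible
        # for putting the batch on the same device as the module's weights
        return self.conv(x)


net = TinyNet()
x = torch.randn(2, 3, 8, 8)  # Lightning would already have moved this to the right device
print(net(x).shape)          # torch.Size([2, 1, 8, 8])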

View file

@@ -73,7 +73,7 @@ class UNetModule(pl.LightningModule):
         columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
         for i, (img, mask, pred, pred_bin) in enumerate(
             zip(  # TODO: use comprehension list to zip the dictionnary
-                tensors["images"].cpu(),
+                tensors["data"].cpu(),
                 tensors["ground_truth"].cpu(),
                 tensors["prediction"].cpu(),
                 tensors["binary"]
@@ -121,7 +121,7 @@ class UNetModule(pl.LightningModule):
         if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip(  # TODO: use comprehension list to zip the dictionnary
-                    tensors["images"].cpu(),
+                    tensors["data"].cpu(),
                     tensors["ground_truth"].cpu(),
                     tensors["prediction"].cpu(),
                     tensors["binary"]
@@ -150,11 +150,12 @@ class UNetModule(pl.LightningModule):
                     ]
                 )
 
-        return metrics
+        return metrics, rows
 
     def validation_epoch_end(self, validation_outputs):
         # unpacking
-        metricss, rowss = validation_outputs
+        metricss = [v[0] for v in validation_outputs]
+        rowss = [v[1] for v in validation_outputs]
 
         # metrics flattening
         metrics = {
@@ -162,7 +163,7 @@ class UNetModule(pl.LightningModule):
             "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
             "bce": torch.stack([d["bce"] for d in metricss]).mean(),
             "mae": torch.stack([d["mae"] for d in metricss]).mean(),
-            "accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
+            "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
         }
 
         # log metrics
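Since validation_step returns a (metrics, rows) tuple per batch, validation_epoch_end (a hook from the pre-2.0 Lightning API) receives a list of such tuples and has to split it before stacking the per-batch metric tensors. A stripped-down sketch of that aggregation, with made-up metric values:

import torch

# what validation_epoch_end receives: one (metrics, rows) tuple per validation batch
validation_outputs = [
    ({"dice": torch.tensor(0.81), "bce": torch.tensor(0.12)}, [["table", "row", "one"]]),
    ({"dice": torch.tensor(0.85), "bce": torch.tensor(0.10)}, [["table", "row", "two"]]),
]

# split the tuples, then average each metric across batches
metricss = [v[0] for v in validation_outputs]
rowss = [v[1] for v in validation_outputs]

metrics = {key: torch.stack([d[key] for d in metricss]).mean() for key in metricss[0]}
rows = [row for batch_rows in rowss for row in batch_rows]

print(metrics["dice"])  # tensor(0.8300)
print(len(rows))        # 2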