refactor: making the code actually work

Former-commit-id: 302f77ef109ca44eef8c4ce6c5b3f59ba2891884 [formerly 5ae3d84ea691440283f73b5f12e2c85b7d95b191]
Former-commit-id: 33cbfef93d9be0cefd526e1cdbd7c0435db4c613
This commit is contained in:
Laurent Fainsin 2022-07-10 17:12:00 +02:00
parent 81cbfd6212
commit 8d903e7ad1
6 changed files with 28 additions and 23 deletions

View file

@ -1 +1 @@
from .dataloader import SyntheticSphere
from .dataloader import Spheres

View file

@ -1,6 +1,6 @@
import albumentations as A
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Subset
import wandb
from utils import RandomPaste
@ -8,7 +8,7 @@ from utils import RandomPaste
from .dataset import LabeledDataset, SyntheticDataset
class SyntheticSphere(pl.LightningDataModule):
class Spheres(pl.LightningDataModule):
def __init__(self):
super().__init__()
@ -25,7 +25,7 @@ class SyntheticSphere(pl.LightningDataModule):
)
dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
return DataLoader(
dataset,
@ -37,6 +37,7 @@ class SyntheticSphere(pl.LightningDataModule):
def val_dataloader(self):
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
return DataLoader(
dataset,

View file

@ -63,7 +63,7 @@ class LabeledDataset(Dataset):
# convert image & mask to Tensor float in [0, 1]
post_process = A.Compose(
[
# A.SmallestMaxSize(1024),
A.SmallestMaxSize(1024),
A.ToFloat(max_value=255),
ToTensorV2(),
],

View file

@ -1,11 +1,12 @@
import logging
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import WandbLogger
import wandb
from unet import UNet
from data import Spheres
from unet import UNetModule
CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@ -18,9 +19,9 @@ CONFIG = {
"PIN_MEMORY": True,
"BENCHMARK": True,
"DEVICE": "gpu",
"WORKERS": 10,
"WORKERS": 14,
"EPOCHS": 1,
"BATCH_SIZE": 32,
"BATCH_SIZE": 16 * 3,
"LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 1e-8,
"MOMENTUM": 0.9,
@ -45,7 +46,7 @@ if __name__ == "__main__":
pl.seed_everything(69420, workers=True)
# Create network
net = UNet(
model = UNetModule(
n_channels=CONFIG["N_CHANNELS"],
n_classes=CONFIG["N_CLASSES"],
batch_size=CONFIG["BATCH_SIZE"],
@ -54,27 +55,31 @@ if __name__ == "__main__":
)
# log gradients and weights regularly
logger.watch(net, log="all")
logger.watch(model, log="all")
# create checkpoint callback
checkpoint_callback = pl.ModelCheckpoint(
checkpoint_callback = ModelCheckpoint(
dirpath="checkpoints",
monitor="val/dice",
)
# Create the dataloaders
datamodule = Spheres()
# Create the trainer
trainer = pl.Trainer(
max_epochs=CONFIG["EPOCHS"],
accelerator=CONFIG["DEVICE"],
# precision=16,
benchmark=CONFIG["BENCHMARK"],
val_check_interval=100,
callbacks=RichProgressBar(),
# profiler="simple",
# precision=16,
logger=logger,
log_every_n_steps=1,
val_check_interval=100,
callbacks=RichProgressBar(),
)
trainer.fit(model=net)
trainer.fit(model=model, datamodule=datamodule)
# stop wandb
wandb.run.finish()

View file

@ -29,8 +29,6 @@ class UNet(nn.Module):
def forward(self, x):
skips = []
x = x.to(self.device)
x = self.inc(x)
for down in self.downs:

View file

@ -73,7 +73,7 @@ class UNetModule(pl.LightningModule):
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
for i, (img, mask, pred, pred_bin) in enumerate(
zip( # TODO: use comprehension list to zip the dictionnary
tensors["images"].cpu(),
tensors["data"].cpu(),
tensors["ground_truth"].cpu(),
tensors["prediction"].cpu(),
tensors["binary"]
@ -121,7 +121,7 @@ class UNetModule(pl.LightningModule):
if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
for i, (img, mask, pred, pred_bin) in enumerate(
zip( # TODO: use comprehension list to zip the dictionnary
tensors["images"].cpu(),
tensors["data"].cpu(),
tensors["ground_truth"].cpu(),
tensors["prediction"].cpu(),
tensors["binary"]
@ -150,11 +150,12 @@ class UNetModule(pl.LightningModule):
]
)
return metrics
return metrics, rows
def validation_epoch_end(self, validation_outputs):
# unpacking
metricss, rowss = validation_outputs
metricss = [v[0] for v in validation_outputs]
rowss = [v[1] for v in validation_outputs]
# metrics flattening
metrics = {
@ -162,7 +163,7 @@ class UNetModule(pl.LightningModule):
"dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
"bce": torch.stack([d["bce"] for d in metricss]).mean(),
"mae": torch.stack([d["mae"] for d in metricss]).mean(),
"accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
"accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
}
# log metrics