diff --git a/src/data/__init__.py b/src/data/__init__.py
index 81370c5..d36c9de 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -1 +1 @@
-from .dataloader import SyntheticSphere
+from .dataloader import Spheres
diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index 3d585c2..2b44a25 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -1,6 +1,6 @@
 import albumentations as A
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset
 import wandb
 
 from utils import RandomPaste
@@ -8,7 +8,7 @@ from utils import RandomPaste
 from .dataset import LabeledDataset, SyntheticDataset
 
 
-class SyntheticSphere(pl.LightningDataModule):
+class Spheres(pl.LightningDataModule):
     def __init__(self):
         super().__init__()
 
@@ -25,7 +25,7 @@ class SyntheticSphere(pl.LightningDataModule):
         )
 
         dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
 
         return DataLoader(
             dataset,
@@ -37,6 +37,7 @@
 
     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
+        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
 
         return DataLoader(
             dataset,
diff --git a/src/data/dataset.py b/src/data/dataset.py
index 01c50ea..d30931f 100644
--- a/src/data/dataset.py
+++ b/src/data/dataset.py
@@ -63,7 +63,7 @@ class LabeledDataset(Dataset):
         # convert image & mask to Tensor float in [0, 1]
         post_process = A.Compose(
             [
-                # A.SmallestMaxSize(1024),
+                A.SmallestMaxSize(1024),
                 A.ToFloat(max_value=255),
                 ToTensorV2(),
             ],
diff --git a/src/train.py b/src/train.py
index 8e34144..d278141 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,11 +1,12 @@
 import logging
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import RichProgressBar
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 import wandb
-from unet import UNet
+from data import Spheres
+from unet import UNetModule
 
 CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@@ -18,9 +19,9 @@ CONFIG = {
     "PIN_MEMORY": True,
     "BENCHMARK": True,
     "DEVICE": "gpu",
-    "WORKERS": 10,
+    "WORKERS": 14,
     "EPOCHS": 1,
-    "BATCH_SIZE": 32,
+    "BATCH_SIZE": 16 * 3,
     "LEARNING_RATE": 1e-4,
     "WEIGHT_DECAY": 1e-8,
     "MOMENTUM": 0.9,
@@ -45,7 +46,7 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)
 
     # Create network
-    net = UNet(
+    model = UNetModule(
         n_channels=CONFIG["N_CHANNELS"],
         n_classes=CONFIG["N_CLASSES"],
         batch_size=CONFIG["BATCH_SIZE"],
@@ -54,27 +55,31 @@ if __name__ == "__main__":
     )
 
     # log gradients and weights regularly
-    logger.watch(net, log="all")
+    logger.watch(model, log="all")
 
     # create checkpoint callback
-    checkpoint_callback = pl.ModelCheckpoint(
+    checkpoint_callback = ModelCheckpoint(
         dirpath="checkpoints",
         monitor="val/dice",
     )
 
+    # Create the dataloaders
+    datamodule = Spheres()
+
     # Create the trainer
     trainer = pl.Trainer(
         max_epochs=CONFIG["EPOCHS"],
         accelerator=CONFIG["DEVICE"],
-        # precision=16,
         benchmark=CONFIG["BENCHMARK"],
-        val_check_interval=100,
-        callbacks=RichProgressBar(),
+        # profiler="simple",
+        # precision=16,
         logger=logger,
         log_every_n_steps=1,
+        val_check_interval=100,
+        callbacks=RichProgressBar(),
     )
 
-    trainer.fit(model=net)
+    trainer.fit(model=model, datamodule=datamodule)
 
     # stop wandb
     wandb.run.finish()
diff --git a/src/unet/model.py b/src/unet/model.py
index bddfb14..a5b8e2a 100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -29,8 +29,6 @@ class UNet(nn.Module):
     def forward(self, x):
         skips = []
 
-        x = x.to(self.device)
-
         x = self.inc(x)
 
         for down in self.downs:
diff --git a/src/unet/module.py b/src/unet/module.py
index a08a383..7c07962 100644
--- a/src/unet/module.py
+++ b/src/unet/module.py
@@ -73,7 +73,7 @@
         columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
         for i, (img, mask, pred, pred_bin) in enumerate(
             zip(  # TODO: use comprehension list to zip the dictionnary
-                tensors["images"].cpu(),
+                tensors["data"].cpu(),
                 tensors["ground_truth"].cpu(),
                 tensors["prediction"].cpu(),
                 tensors["binary"]
@@ -121,7 +121,7 @@
         if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip(  # TODO: use comprehension list to zip the dictionnary
-                    tensors["images"].cpu(),
+                    tensors["data"].cpu(),
                     tensors["ground_truth"].cpu(),
                     tensors["prediction"].cpu(),
                     tensors["binary"]
@@ -150,11 +150,12 @@
             ]
         )
 
-        return metrics
+        return metrics, rows
 
     def validation_epoch_end(self, validation_outputs):
         # unpacking
-        metricss, rowss = validation_outputs
+        metricss = [v[0] for v in validation_outputs]
+        rowss = [v[1] for v in validation_outputs]
 
         # metrics flattening
         metrics = {
@@ -162,7 +163,7 @@
             "dice": torch.stack([d["dice"] for d in metricss]).mean(),
             "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
             "bce": torch.stack([d["bce"] for d in metricss]).mean(),
             "mae": torch.stack([d["mae"] for d in metricss]).mean(),
-            "accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
+            "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
         }
 
         # log metrics
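
Note on the `Subset` subsampling in `dataloader.py`: the `+ 1` matters because `len(dataset) // 10000` is 0 whenever the dataset is smaller than the target count, and `range()` rejects a step of 0, which is presumably why the original commented-out line never shipped. A minimal sketch of the stride arithmetic, where `ToyDataset`, `target`, and the sizes are illustrative stand-ins (only `Subset` and the stride expression come from the diff):

```python
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Stand-in for SyntheticDataset: each item is just its own index."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        return idx


dataset = ToyDataset(25_000)
target = 10_000

# len(dataset) // target + 1 is never 0, unlike the plain floor division,
# and it caps the subset at roughly `target` samples when the dataset is big.
step = len(dataset) // target + 1  # 3 here
subset = Subset(dataset, list(range(0, len(dataset), step)))

print(len(subset))  # 8334, under the 10k target; a step of 25_000 // 10_000 would give 12500
```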
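Note on the `validation_step` change in `module.py`: once each step returns `metrics, rows`, `validation_epoch_end` receives a list of per-batch tuples, so the old `metricss, rowss = validation_outputs` unpacking (which only binds without error when there happen to be exactly two batches, and even then to the wrong objects) is replaced by two comprehensions, and every aggregation, including `accuracy`, must then read from `metricss` rather than `validation_outputs`. A runnable toy sketch of that contract, with made-up values standing in for the real dice/BCE tensors and wandb table rows:

```python
import torch

# one (metrics, rows) tuple per validation batch, as validation_step now returns
validation_outputs = [
    ({"dice": torch.tensor(0.91), "accuracy": torch.tensor(0.97)}, [["id-0", 0.91]]),
    ({"dice": torch.tensor(0.87), "accuracy": torch.tensor(0.95)}, [["id-1", 0.87]]),
    ({"dice": torch.tensor(0.89), "accuracy": torch.tensor(0.96)}, [["id-2", 0.89]]),
]

# split the list of tuples, as in the patched validation_epoch_end
metricss = [v[0] for v in validation_outputs]
rowss = [v[1] for v in validation_outputs]

# aggregation must iterate over metricss, not validation_outputs:
# indexing a (metrics, rows) tuple with "accuracy" would raise TypeError
metrics = {
    "dice": torch.stack([d["dice"] for d in metricss]).mean(),
    "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
}

rows = [row for batch_rows in rowss for row in batch_rows]  # flatten for one table
print(metrics, rows)
```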