diff --git a/.gitignore b/.gitignore index 40eb468..d8b8e64 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ lightning_logs/ checkpoints/ *.pth *.onnx +*.ckpt *.png *.jpg diff --git a/src/train.py b/src/train.py index c54988a..656eea4 100644 --- a/src/train.py +++ b/src/train.py @@ -1,17 +1,13 @@ import logging -import albumentations as A import pytorch_lightning as pl import torch -from albumentations.pytorch import ToTensorV2 from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.loggers import WandbLogger from torch.utils.data import DataLoader import wandb -from src.utils.dataset import SphereDataset from unet import UNet -from utils.paste import RandomPaste CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", @@ -52,7 +48,7 @@ if __name__ == "__main__": # seed random generators pl.seed_everything(69420, workers=True) - # 0. Create network + # Create network net = UNet( n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], @@ -64,53 +60,13 @@ if __name__ == "__main__": # log gradients and weights regularly logger.watch(net, log="all") - # 1. Create transforms - tf_train = A.Compose( - [ - A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]), - A.Flip(), - A.ColorJitter(), - RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]), - A.GaussianBlur(), - A.ISONoise(), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) - - # 2. Create datasets - ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train) - ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"]) - - # 2.5. Create subset, if uncommented - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) - # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) - # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) - - # 3. Create data loaders - train_loader = DataLoader( - ds_train, - shuffle=True, - batch_size=CONFIG["BATCH_SIZE"], - num_workers=CONFIG["WORKERS"], - pin_memory=CONFIG["PIN_MEMORY"], - ) - val_loader = DataLoader( - ds_valid, - shuffle=False, - drop_last=True, - batch_size=1, - num_workers=CONFIG["WORKERS"], - pin_memory=CONFIG["PIN_MEMORY"], - ) - - # 4. Create the trainer + # Create the trainer trainer = pl.Trainer( max_epochs=CONFIG["EPOCHS"], accelerator=CONFIG["DEVICE"], # precision=16, - auto_scale_batch_size="binsearch", - auto_lr_find=True, + # auto_scale_batch_size="binsearch", + # auto_lr_find=True, benchmark=CONFIG["BENCHMARK"], val_check_interval=100, callbacks=RichProgressBar(), @@ -119,11 +75,8 @@ if __name__ == "__main__": ) try: - trainer.fit( - model=net, - train_dataloaders=train_loader, - val_dataloaders=val_loader, - ) + trainer.tune(net) + trainer.fit(model=net) except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") raise diff --git a/src/unet/model.py b/src/unet/model.py index 5f613f4..ccc036c 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -2,10 +2,15 @@ import itertools +import albumentations as A import pytorch_lightning as pl +from albumentations.pytorch import ToTensorV2 +from torch.utils.data import DataLoader import wandb +from src.utils.dataset import SphereDataset from utils.dice import dice_coeff +from utils.paste import RandomPaste from .blocks import * @@ -24,6 +29,9 @@ class UNet(pl.LightningModule): self.learning_rate = learning_rate self.batch_size = batch_size + # log hyperparameters + self.save_hyperparameters() + # Network self.inc = DoubleConv(n_channels, features[0]) @@ -59,6 +67,42 @@ class UNet(pl.LightningModule): return x + def train_dataloader(self): + tf_train = A.Compose( + [ + A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), + A.Flip(), + A.ColorJitter(), + RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), + A.GaussianBlur(), + A.ISONoise(), + A.ToFloat(max_value=255), + ToTensorV2(), + ], + ) + + ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) + ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) + + return DataLoader( + ds_train, + batch_size=self.batch_size, + shuffle=True, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) + + def val_dataloader(self): + ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) + + return DataLoader( + ds_valid, + shuffle=False, + batch_size=1, + num_workers=wandb.config.WORKERS, + pin_memory=wandb.config.PIN_MEMORY, + ) + def training_step(self, batch, batch_idx): # unpacking images, masks_true = batch @@ -109,8 +153,8 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() dice = dice_coeff(masks_pred_bin, masks_true) + rows = [] if batch_idx < 6: - rows = [] for i, (img, mask, pred, pred_bin) in enumerate( zip( images.cpu(), @@ -157,11 +201,14 @@ class UNet(pl.LightningModule): rows = list(itertools.chain.from_iterable(rowss)) # logging - self.logger.log_table( - key="val/predictions", - columns=columns, - data=rows, - ) + try: + self.logger.log_table( + key="val/predictions", + columns=columns, + data=rows, + ) + except: + pass self.log_dict( { "val/accuracy": accuracy, @@ -229,7 +276,7 @@ class UNet(pl.LightningModule): def configure_optimizers(self): optimizer = torch.optim.RMSprop( self.parameters(), - lr=wandb.config.LEARNING_RATE, + lr=self.learning_rate, weight_decay=wandb.config.WEIGHT_DECAY, momentum=wandb.config.MOMENTUM, )