feat: automatic batch/lr guessing sorta works

Former-commit-id: 346f1f55bab70df44bf15ab04c9a97f256e3d19c [formerly e027de4b57339dccc540ec11cfe81d5278c20d57]
Former-commit-id: 9f3537abccca7ab3d433df318cc7acf6bfc610c4
This commit is contained in:
Laurent Fainsin 2022-07-06 11:57:21 +02:00
parent f9ca8532a0
commit 5a74af6cdb
3 changed files with 61 additions and 60 deletions

1
.gitignore vendored
View file

@ -9,6 +9,7 @@ lightning_logs/
checkpoints/ checkpoints/
*.pth *.pth
*.onnx *.onnx
*.ckpt
*.png *.png
*.jpg *.jpg

View file

@ -1,17 +1,13 @@
import logging import logging
import albumentations as A
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from albumentations.pytorch import ToTensorV2
from pytorch_lightning.callbacks import RichProgressBar from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import wandb import wandb
from src.utils.dataset import SphereDataset
from unet import UNet from unet import UNet
from utils.paste import RandomPaste
CONFIG = { CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@ -52,7 +48,7 @@ if __name__ == "__main__":
# seed random generators # seed random generators
pl.seed_everything(69420, workers=True) pl.seed_everything(69420, workers=True)
# 0. Create network # Create network
net = UNet( net = UNet(
n_channels=CONFIG["N_CHANNELS"], n_channels=CONFIG["N_CHANNELS"],
n_classes=CONFIG["N_CLASSES"], n_classes=CONFIG["N_CLASSES"],
@ -64,53 +60,13 @@ if __name__ == "__main__":
# log gradients and weights regularly # log gradients and weights regularly
logger.watch(net, log="all") logger.watch(net, log="all")
# 1. Create transforms # Create the trainer
tf_train = A.Compose(
[
A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
A.Flip(),
A.ColorJitter(),
RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
# 2. Create datasets
ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
# 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
# ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
# ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
# 3. Create data loaders
train_loader = DataLoader(
ds_train,
shuffle=True,
batch_size=CONFIG["BATCH_SIZE"],
num_workers=CONFIG["WORKERS"],
pin_memory=CONFIG["PIN_MEMORY"],
)
val_loader = DataLoader(
ds_valid,
shuffle=False,
drop_last=True,
batch_size=1,
num_workers=CONFIG["WORKERS"],
pin_memory=CONFIG["PIN_MEMORY"],
)
# 4. Create the trainer
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=CONFIG["EPOCHS"], max_epochs=CONFIG["EPOCHS"],
accelerator=CONFIG["DEVICE"], accelerator=CONFIG["DEVICE"],
# precision=16, # precision=16,
auto_scale_batch_size="binsearch", # auto_scale_batch_size="binsearch",
auto_lr_find=True, # auto_lr_find=True,
benchmark=CONFIG["BENCHMARK"], benchmark=CONFIG["BENCHMARK"],
val_check_interval=100, val_check_interval=100,
callbacks=RichProgressBar(), callbacks=RichProgressBar(),
@ -119,11 +75,8 @@ if __name__ == "__main__":
) )
try: try:
trainer.fit( trainer.tune(net)
model=net, trainer.fit(model=net)
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth") torch.save(net.state_dict(), "INTERRUPTED.pth")
raise raise

View file

@ -2,10 +2,15 @@
import itertools import itertools
import albumentations as A
import pytorch_lightning as pl import pytorch_lightning as pl
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import wandb import wandb
from src.utils.dataset import SphereDataset
from utils.dice import dice_coeff from utils.dice import dice_coeff
from utils.paste import RandomPaste
from .blocks import * from .blocks import *
@ -24,6 +29,9 @@ class UNet(pl.LightningModule):
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.batch_size = batch_size self.batch_size = batch_size
# log hyperparameters
self.save_hyperparameters()
# Network # Network
self.inc = DoubleConv(n_channels, features[0]) self.inc = DoubleConv(n_channels, features[0])
@ -59,6 +67,42 @@ class UNet(pl.LightningModule):
return x return x
def train_dataloader(self):
tf_train = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(),
A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
return DataLoader(
ds_train,
batch_size=self.batch_size,
shuffle=True,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def val_dataloader(self):
ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
return DataLoader(
ds_valid,
shuffle=False,
batch_size=1,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# unpacking # unpacking
images, masks_true = batch images, masks_true = batch
@ -109,8 +153,8 @@ class UNet(pl.LightningModule):
accuracy = (masks_true == masks_pred_bin).float().mean() accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true) dice = dice_coeff(masks_pred_bin, masks_true)
if batch_idx < 6:
rows = [] rows = []
if batch_idx < 6:
for i, (img, mask, pred, pred_bin) in enumerate( for i, (img, mask, pred, pred_bin) in enumerate(
zip( zip(
images.cpu(), images.cpu(),
@ -157,11 +201,14 @@ class UNet(pl.LightningModule):
rows = list(itertools.chain.from_iterable(rowss)) rows = list(itertools.chain.from_iterable(rowss))
# logging # logging
try:
self.logger.log_table( self.logger.log_table(
key="val/predictions", key="val/predictions",
columns=columns, columns=columns,
data=rows, data=rows,
) )
except:
pass
self.log_dict( self.log_dict(
{ {
"val/accuracy": accuracy, "val/accuracy": accuracy,
@ -229,7 +276,7 @@ class UNet(pl.LightningModule):
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.RMSprop( optimizer = torch.optim.RMSprop(
self.parameters(), self.parameters(),
lr=wandb.config.LEARNING_RATE, lr=self.learning_rate,
weight_decay=wandb.config.WEIGHT_DECAY, weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM, momentum=wandb.config.MOMENTUM,
) )