feat: automatic batch size / lr guessing sorta works

Former-commit-id: 346f1f55bab70df44bf15ab04c9a97f256e3d19c [formerly e027de4b57339dccc540ec11cfe81d5278c20d57]
Former-commit-id: 9f3537abccca7ab3d433df318cc7acf6bfc610c4
Laurent Fainsin 2022-07-06 11:57:21 +02:00
parent f9ca8532a0
commit 5a74af6cdb
3 changed files with 61 additions and 60 deletions

.gitignore

@@ -9,6 +9,7 @@ lightning_logs/
checkpoints/
*.pth
*.onnx
+*.ckpt
*.png
*.jpg
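The new `*.ckpt` rule is presumably needed because Lightning writes checkpoint files into the working tree: besides the regular checkpoint callbacks, the batch-size finder and the LR finder each dump a temporary `.ckpt` restore point before experimenting with the model.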


@@ -1,17 +1,13 @@
import logging
-import albumentations as A
import pytorch_lightning as pl
import torch
-from albumentations.pytorch import ToTensorV2
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
-from torch.utils.data import DataLoader
import wandb
-from src.utils.dataset import SphereDataset
from unet import UNet
-from utils.paste import RandomPaste
CONFIG = {
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@@ -52,7 +48,7 @@ if __name__ == "__main__":
# seed random generators
pl.seed_everything(69420, workers=True)
-    # 0. Create network
+    # Create network
net = UNet(
n_channels=CONFIG["N_CHANNELS"],
n_classes=CONFIG["N_CLASSES"],
@@ -64,53 +60,13 @@ if __name__ == "__main__":
# log gradients and weights regularly
logger.watch(net, log="all")
-    # 1. Create transforms
-    tf_train = A.Compose(
-        [
-            A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
-            A.Flip(),
-            A.ColorJitter(),
-            RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
-            A.GaussianBlur(),
-            A.ISONoise(),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
-    # 2. Create datasets
-    ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
-    ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
-    # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
-    # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
-    # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
-    # 3. Create data loaders
-    train_loader = DataLoader(
-        ds_train,
-        shuffle=True,
-        batch_size=CONFIG["BATCH_SIZE"],
-        num_workers=CONFIG["WORKERS"],
-        pin_memory=CONFIG["PIN_MEMORY"],
-    )
-    val_loader = DataLoader(
-        ds_valid,
-        shuffle=False,
-        drop_last=True,
-        batch_size=1,
-        num_workers=CONFIG["WORKERS"],
-        pin_memory=CONFIG["PIN_MEMORY"],
-    )
-    # 4. Create the trainer
+    # Create the trainer
trainer = pl.Trainer(
max_epochs=CONFIG["EPOCHS"],
accelerator=CONFIG["DEVICE"],
# precision=16,
-        auto_scale_batch_size="binsearch",
-        auto_lr_find=True,
+        # auto_scale_batch_size="binsearch",
+        # auto_lr_find=True,
benchmark=CONFIG["BENCHMARK"],
val_check_interval=100,
callbacks=RichProgressBar(),
@@ -119,11 +75,8 @@ if __name__ == "__main__":
)
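After this refactor the model resolves its data settings from `wandb.config.*`, so the `CONFIG` dict has to end up in the W&B run config before `tune()`/`fit()` run. The diff does not show where that happens; a minimal sketch of one way to wire it, assuming the `WandbLogger` is created in this script (the project name is hypothetical):

```python
# Hypothetical sketch: WandbLogger forwards extra kwargs to wandb.init(),
# so config=CONFIG makes wandb.config.IMG_SIZE etc. resolvable inside the
# LightningModule's dataloader hooks.
logger = WandbLogger(
    project="U-Net",  # assumed name, not shown in this diff
    config=CONFIG,
)
```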
try:
-        trainer.fit(
-            model=net,
-            train_dataloaders=train_loader,
-            val_dataloaders=val_loader,
-        )
+        trainer.tune(net)
+        trainer.fit(model=net)
except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth")
raise
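In Lightning ~1.6, `trainer.tune(model)` runs the batch-size finder and the LR finder according to the `auto_scale_batch_size` / `auto_lr_find` trainer flags, writing the results back into `model.batch_size` and `model.learning_rate`. Since both flags end up commented out above, `tune()` is a no-op as shown, which may be part of why this only "sorta works". The finders can also be called explicitly; a sketch against the ~1.6 API, not code from this repo:

```python
# Sketch (PyTorch Lightning ~1.6): explicit equivalents of what
# trainer.tune(net) does when the auto_* flags are enabled.
new_bs = trainer.tuner.scale_batch_size(net, mode="binsearch")  # also sets net.batch_size
lr_finder = trainer.tuner.lr_find(net)                          # LR range test
net.learning_rate = lr_finder.suggestion()                      # adopt the suggested LR
```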


@@ -2,10 +2,15 @@
import itertools
+import albumentations as A
import pytorch_lightning as pl
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import DataLoader
import wandb
+from src.utils.dataset import SphereDataset
from utils.dice import dice_coeff
+from utils.paste import RandomPaste
from .blocks import *
@@ -24,6 +29,9 @@ class UNet(pl.LightningModule):
self.learning_rate = learning_rate
self.batch_size = batch_size
+        # log hyperparameters
+        self.save_hyperparameters()
# Network
self.inc = DoubleConv(n_channels, features[0])
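`save_hyperparameters()` records the arguments `__init__` received into `self.hparams`, logs them with the run, and embeds them in checkpoints, so the model can be re-instantiated without repeating the constructor arguments. A sketch, with constructor values assumed rather than taken from the repo:

```python
# Sketch: what save_hyperparameters() buys (values are placeholders).
net = UNet(n_channels=3, n_classes=1, batch_size=16, learning_rate=1e-4)
print(net.hparams.batch_size)  # 16, captured from __init__ automatically

# Checkpoints now carry the hparams, so reloading needs no arguments:
restored = UNet.load_from_checkpoint("checkpoints/last.ckpt")  # path hypothetical
```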
@@ -59,6 +67,42 @@ class UNet(pl.LightningModule):
return x
+    def train_dataloader(self):
+        tf_train = A.Compose(
+            [
+                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+                A.Flip(),
+                A.ColorJitter(),
+                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
+                A.GaussianBlur(),
+                A.ISONoise(),
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+
+        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
+        ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
+
+        return DataLoader(
+            ds_train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def val_dataloader(self):
+        ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+
+        return DataLoader(
+            ds_valid,
+            shuffle=False,
+            batch_size=1,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
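Moving `train_dataloader`/`val_dataloader` into the `LightningModule` is what makes the batch-size finder viable: the tuner mutates `self.batch_size`, then rebuilds the train loader through this hook, so the new value flows into the `DataLoader`. Passing ready-made loaders to `fit()`, as the old script did, gives the tuner nothing to rebuild. Condensed sketch of that contract:

```python
# Condensed sketch: the finder sets `self.batch_size` (the attribute name
# it looks for by default), then calls train_dataloader() to rebuild.
def train_dataloader(self):
    ds = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG)
    return DataLoader(ds, batch_size=self.batch_size, shuffle=True)
```

A `LightningDataModule` would support the same flow while keeping dataset and config concerns out of the model; the tuner looks up `batch_size` on either object.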
def training_step(self, batch, batch_idx):
# unpacking
images, masks_true = batch
@@ -109,8 +153,8 @@ class UNet(pl.LightningModule):
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
-        rows = []
+        if batch_idx < 6:
+            rows = []
for i, (img, mask, pred, pred_bin) in enumerate(
zip(
images.cpu(),
@@ -157,11 +201,14 @@ class UNet(pl.LightningModule):
rows = list(itertools.chain.from_iterable(rowss))
# logging
-        self.logger.log_table(
-            key="val/predictions",
-            columns=columns,
-            data=rows,
-        )
+        try:
+            self.logger.log_table(
+                key="val/predictions",
+                columns=columns,
+                data=rows,
+            )
+        except Exception:
+            # don't let a failed table upload abort validation; a bare
+            # except would also swallow KeyboardInterrupt
+            pass
self.log_dict(
{
"val/accuracy": accuracy,
@@ -229,7 +276,7 @@ class UNet(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),
-            lr=wandb.config.LEARNING_RATE,
+            lr=self.learning_rate,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
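Switching the optimizer from `wandb.config.LEARNING_RATE` to `self.learning_rate` is what makes LR tuning take effect: the LR finder writes its suggestion to the module's `learning_rate` (or `lr`) attribute, and `configure_optimizers()` only runs at fit start, after tuning, so it picks up the tuned value. Reading the static W&B config here would silently discard it. Order of operations, values hypothetical:

```python
# Sketch: why reading the attribute matters (lr value is made up).
net.learning_rate = 1e-3  # what trainer.tune() effectively does
trainer.fit(model=net)    # configure_optimizers() runs now, so RMSprop
                          # is built with lr=1e-3, the tuned value
```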