Mirror of https://github.com/Laurent2916/REVA-QCAV.git, synced 2024-09-19 11:35:28 +00:00
feat: automatic batch/lr guessing sorta works
Former-commit-id: 346f1f55bab70df44bf15ab04c9a97f256e3d19c [formerly e027de4b57339dccc540ec11cfe81d5278c20d57] Former-commit-id: 9f3537abccca7ab3d433df318cc7acf6bfc610c4
This commit is contained in:
parent f9ca8532a0
commit 5a74af6cdb
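The feature in question is PyTorch Lightning's tuner: give the 1.x Trainer `auto_scale_batch_size="binsearch"` and `auto_lr_find=True`, call `trainer.tune(model)` before `fit`, and it binary-searches the largest batch size that fits in memory and runs an LR-range test, writing the results back to `model.batch_size` and `model.learning_rate`. For that to work, the module itself has to own those attributes and build its own dataloaders, which is exactly what the diff below moves toward. A minimal self-contained sketch of the pattern, assuming the Lightning 1.x API (`tune()` and the `auto_*` flags were removed in 2.0); `TinyRegressor` and its random data are illustrative stand-ins, not code from this repository:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyRegressor(pl.LightningModule):
    def __init__(self, learning_rate=1e-4, batch_size=16):
        super().__init__()
        # the tuner reads and overwrites these attributes in place
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.net = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        # read the (possibly tuned) rate here, not a frozen config value
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        # the batch-size scaler calls this again for every size it probes
        ds = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))
        return DataLoader(ds, batch_size=self.batch_size, shuffle=True)


if __name__ == "__main__":
    model = TinyRegressor()
    trainer = pl.Trainer(
        max_epochs=1,
        auto_scale_batch_size="binsearch",  # binary-search the largest fitting batch
        auto_lr_find=True,                  # run the LR-range test
    )
    trainer.tune(model)  # sets model.batch_size and model.learning_rate
    trainer.fit(model)
```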
.gitignore (vendored): 1 change

@@ -9,6 +9,7 @@ lightning_logs/
 checkpoints/
 *.pth
 *.onnx
+*.ckpt
 *.png
 *.jpg
 
src/train.py: 59 changes

@@ -1,17 +1,13 @@
 import logging
 
-import albumentations as A
 import pytorch_lightning as pl
 import torch
-from albumentations.pytorch import ToTensorV2
 from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 from torch.utils.data import DataLoader
 
 import wandb
-from src.utils.dataset import SphereDataset
 from unet import UNet
-from utils.paste import RandomPaste
 
 CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@@ -52,7 +48,7 @@ if __name__ == "__main__":
     # seed random generators
     pl.seed_everything(69420, workers=True)
 
-    # 0. Create network
+    # Create network
     net = UNet(
         n_channels=CONFIG["N_CHANNELS"],
         n_classes=CONFIG["N_CLASSES"],
@@ -64,53 +60,13 @@ if __name__ == "__main__":
     # log gradients and weights regularly
     logger.watch(net, log="all")
 
-    # 1. Create transforms
-    tf_train = A.Compose(
-        [
-            A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
-            A.Flip(),
-            A.ColorJitter(),
-            RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
-            A.GaussianBlur(),
-            A.ISONoise(),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
-
-    # 2. Create datasets
-    ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
-    ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
-
-    # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
-    # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
-    # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
-
-    # 3. Create data loaders
-    train_loader = DataLoader(
-        ds_train,
-        shuffle=True,
-        batch_size=CONFIG["BATCH_SIZE"],
-        num_workers=CONFIG["WORKERS"],
-        pin_memory=CONFIG["PIN_MEMORY"],
-    )
-    val_loader = DataLoader(
-        ds_valid,
-        shuffle=False,
-        drop_last=True,
-        batch_size=1,
-        num_workers=CONFIG["WORKERS"],
-        pin_memory=CONFIG["PIN_MEMORY"],
-    )
-
-    # 4. Create the trainer
+    # Create the trainer
     trainer = pl.Trainer(
         max_epochs=CONFIG["EPOCHS"],
         accelerator=CONFIG["DEVICE"],
         # precision=16,
-        auto_scale_batch_size="binsearch",
-        auto_lr_find=True,
+        # auto_scale_batch_size="binsearch",
+        # auto_lr_find=True,
         benchmark=CONFIG["BENCHMARK"],
         val_check_interval=100,
         callbacks=RichProgressBar(),
@@ -119,11 +75,8 @@ if __name__ == "__main__":
     )
 
     try:
-        trainer.fit(
-            model=net,
-            train_dataloaders=train_loader,
-            val_dataloaders=val_loader,
-        )
+        trainer.tune(net)
+        trainer.fit(model=net)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise
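Note that `trainer.tune(net)` now runs before `trainer.fit(model=net)`, and neither call receives explicit dataloaders anymore: the batch-size scaler needs to rebuild the training loader each time it probes a new size, so the loaders move into the LightningModule itself (`train_dataloader` and `val_dataloader` in the UNet module below).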

@@ -2,10 +2,15 @@
 
 import itertools
 
+import albumentations as A
 import pytorch_lightning as pl
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import DataLoader
 
 import wandb
+from src.utils.dataset import SphereDataset
 from utils.dice import dice_coeff
+from utils.paste import RandomPaste
 
 from .blocks import *
 
@@ -24,6 +29,9 @@ class UNet(pl.LightningModule):
         self.learning_rate = learning_rate
         self.batch_size = batch_size
 
+        # log hyperparameters
+        self.save_hyperparameters()
+
         # Network
         self.inc = DoubleConv(n_channels, features[0])
 
@@ -59,6 +67,42 @@ class UNet(pl.LightningModule):
 
         return x
 
+    def train_dataloader(self):
+        tf_train = A.Compose(
+            [
+                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+                A.Flip(),
+                A.ColorJitter(),
+                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
+                A.GaussianBlur(),
+                A.ISONoise(),
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+
+        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
+        ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
+
+        return DataLoader(
+            ds_train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def val_dataloader(self):
+        ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+
+        return DataLoader(
+            ds_valid,
+            shuffle=False,
+            batch_size=1,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
     def training_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
@@ -109,8 +153,8 @@ class UNet(pl.LightningModule):
         accuracy = (masks_true == masks_pred_bin).float().mean()
         dice = dice_coeff(masks_pred_bin, masks_true)
 
+        rows = []
         if batch_idx < 6:
-            rows = []
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip(
                     images.cpu(),
@@ -157,11 +201,14 @@ class UNet(pl.LightningModule):
             rows = list(itertools.chain.from_iterable(rowss))
 
         # logging
-        self.logger.log_table(
-            key="val/predictions",
-            columns=columns,
-            data=rows,
-        )
+        try:
+            self.logger.log_table(
+                key="val/predictions",
+                columns=columns,
+                data=rows,
+            )
+        except:
+            pass
         self.log_dict(
             {
                 "val/accuracy": accuracy,
@@ -229,7 +276,7 @@ class UNet(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.RMSprop(
             self.parameters(),
-            lr=wandb.config.LEARNING_RATE,
+            lr=self.learning_rate,
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
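One detail worth noting in `train_dataloader` above: the `Subset` built with a stride of `len(ds_train) // 5000` keeps roughly 5000 evenly spaced samples, so tuning and smoke-test runs stay short. The same idea as a standalone helper (a sketch; the `max(1, ...)` guard against datasets smaller than the target is an addition, not in the original):

```python
import torch
from torch.utils.data import Dataset, Subset, TensorDataset


def subsample(ds: Dataset, target: int = 5000) -> Subset:
    # keep ~target evenly spaced samples; max(1, ...) avoids a zero stride
    # when len(ds) < target (the original code has no such guard)
    stride = max(1, len(ds) // target)
    return Subset(ds, list(range(0, len(ds), stride)))


full = TensorDataset(torch.randn(20_000, 3))
small = subsample(full)
print(len(small))  # -> 5000
```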