mirror of https://github.com/Laurent2916/REVA-QCAV.git, synced 2024-11-08 14:39:00 +00:00
feat: automatic batch/lr guessing sorta works
Former-commit-id: 346f1f55bab70df44bf15ab04c9a97f256e3d19c [formerly e027de4b57339dccc540ec11cfe81d5278c20d57]
Former-commit-id: 9f3537abccca7ab3d433df318cc7acf6bfc610c4
parent f9ca8532a0
commit 5a74af6cdb
.gitignore (vendored): 1 line changed
@@ -9,6 +9,7 @@ lightning_logs/
 checkpoints/
 *.pth
 *.onnx
+*.ckpt

 *.png
 *.jpg
src/train.py: 59 lines changed
@@ -1,17 +1,13 @@
 import logging

-import albumentations as A
 import pytorch_lightning as pl
 import torch
-from albumentations.pytorch import ToTensorV2
 from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
-from torch.utils.data import DataLoader

 import wandb
 from src.utils.dataset import SphereDataset
 from unet import UNet
-from utils.paste import RandomPaste

 CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@@ -52,7 +48,7 @@ if __name__ == "__main__":
     # seed random generators
     pl.seed_everything(69420, workers=True)

-    # 0. Create network
+    # Create network
     net = UNet(
         n_channels=CONFIG["N_CHANNELS"],
         n_classes=CONFIG["N_CLASSES"],
@@ -64,53 +60,13 @@ if __name__ == "__main__":
     # log gradients and weights regularly
     logger.watch(net, log="all")

-    # 1. Create transforms
-    tf_train = A.Compose(
-        [
-            A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
-            A.Flip(),
-            A.ColorJitter(),
-            RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
-            A.GaussianBlur(),
-            A.ISONoise(),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
-
-    # 2. Create datasets
-    ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
-    ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
-
-    # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
-    # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
-    # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
-
-    # 3. Create data loaders
-    train_loader = DataLoader(
-        ds_train,
-        shuffle=True,
-        batch_size=CONFIG["BATCH_SIZE"],
-        num_workers=CONFIG["WORKERS"],
-        pin_memory=CONFIG["PIN_MEMORY"],
-    )
-    val_loader = DataLoader(
-        ds_valid,
-        shuffle=False,
-        drop_last=True,
-        batch_size=1,
-        num_workers=CONFIG["WORKERS"],
-        pin_memory=CONFIG["PIN_MEMORY"],
-    )
-
-    # 4. Create the trainer
+    # Create the trainer
     trainer = pl.Trainer(
         max_epochs=CONFIG["EPOCHS"],
         accelerator=CONFIG["DEVICE"],
         # precision=16,
-        auto_scale_batch_size="binsearch",
-        auto_lr_find=True,
+        # auto_scale_batch_size="binsearch",
+        # auto_lr_find=True,
         benchmark=CONFIG["BENCHMARK"],
         val_check_interval=100,
         callbacks=RichProgressBar(),
@@ -119,11 +75,8 @@ if __name__ == "__main__":
     )

     try:
-        trainer.fit(
-            model=net,
-            train_dataloaders=train_loader,
-            val_dataloaders=val_loader,
-        )
+        trainer.tune(net)
+        trainer.fit(model=net)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise
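The heart of this commit is the switch from fit(model, dataloaders) to tune() followed by fit(). Below is a minimal, self-contained sketch of that pattern, assuming the PyTorch Lightning 1.x API (where trainer.tune() only runs the tuners that the auto_* Trainer flags enable); the tiny model and random data are stand-ins, not the repo's UNet.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, batch_size=8):
        super().__init__()
        # the tuners mutate these attributes in place:
        # auto_lr_find -> self.learning_rate, auto_scale_batch_size -> self.batch_size
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 1)

    def train_dataloader(self):
        # batch_size must be read here, inside the module, so the binsearch
        # tuner can rebuild the loader between trial runs
        ds = TensorDataset(torch.randn(256, 32), torch.randn(256, 1))
        return DataLoader(ds, batch_size=self.batch_size, shuffle=True)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # must reference self.learning_rate for the LR finder's result to apply
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


if __name__ == "__main__":
    model = TinyModel()
    trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size="binsearch", auto_lr_find=True)
    trainer.tune(model)  # batch-size binsearch, then the LR finder
    trainer.fit(model)   # trains with the tuned values

Note that in the hunk above both auto_* flags end up commented out while trainer.tune(net) stays in, so as committed the tune call should be a no-op until they are toggled back on, which fits the "sorta works" in the commit title. The remaining hunks move the data pipeline into the UNet LightningModule so that this pattern can work at all.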
@@ -2,10 +2,15 @@

 import itertools

+import albumentations as A
 import pytorch_lightning as pl
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import DataLoader

 import wandb
+from src.utils.dataset import SphereDataset
 from utils.dice import dice_coeff
+from utils.paste import RandomPaste

 from .blocks import *

@@ -24,6 +29,9 @@ class UNet(pl.LightningModule):
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+
         # log hyperparameters
         self.save_hyperparameters()

         # Network
         self.inc = DoubleConv(n_channels, features[0])

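A side note on the save_hyperparameters() call retained above: it snapshots the __init__ arguments onto self.hparams and into checkpoints, which is how the tuned learning_rate/batch_size values get recorded. A small illustrative sketch (not the repo's class):

import pytorch_lightning as pl


class Example(pl.LightningModule):
    def __init__(self, learning_rate=1e-4, batch_size=16):
        super().__init__()
        # kept as plain attributes so Lightning's tuners can overwrite them
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        # snapshots the init args: self.hparams.learning_rate, self.hparams.batch_size
        self.save_hyperparameters()


model = Example(batch_size=32)
print(model.hparams.batch_size)  # 32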
@@ -59,6 +67,42 @@ class UNet(pl.LightningModule):

         return x

+    def train_dataloader(self):
+        tf_train = A.Compose(
+            [
+                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+                A.Flip(),
+                A.ColorJitter(),
+                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
+                A.GaussianBlur(),
+                A.ISONoise(),
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+
+        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
+        ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
+
+        return DataLoader(
+            ds_train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def val_dataloader(self):
+        ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+
+        return DataLoader(
+            ds_valid,
+            shuffle=False,
+            batch_size=1,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
     def training_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
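The new train_dataloader() carries over the Subset thinning from train.py: one sample every len(ds) // 5000 items, i.e. roughly 5000 evenly spaced images. The same stride trick in isolation (sizes illustrative):

import torch
from torch.utils.data import Subset, TensorDataset

ds = TensorDataset(torch.arange(100_000))
n = 5_000
# one index every len(ds) // n items, i.e. ~n samples spread over the whole set
sub = Subset(ds, list(range(0, len(ds), len(ds) // n)))
print(len(sub))  # 5000

Two caveats worth flagging: the step becomes 0 (and range() raises) if the dataset has fewer than 5000 items, and the method calls torch.utils.data.Subset even though the import hunk above never adds import torch, so the module presumably imports it somewhere outside these hunks.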
@@ -109,8 +153,8 @@ class UNet(pl.LightningModule):
         accuracy = (masks_true == masks_pred_bin).float().mean()
         dice = dice_coeff(masks_pred_bin, masks_true)

-        rows = []
         if batch_idx < 6:
+            rows = []
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip(
                     images.cpu(),
@@ -157,11 +201,14 @@ class UNet(pl.LightningModule):
         rows = list(itertools.chain.from_iterable(rowss))

         # logging
-        self.logger.log_table(
-            key="val/predictions",
-            columns=columns,
-            data=rows,
-        )
+        try:
+            self.logger.log_table(
+                key="val/predictions",
+                columns=columns,
+                data=rows,
+            )
+        except:
+            pass
         self.log_dict(
             {
                 "val/accuracy": accuracy,
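Wrapping log_table in try/except keeps a flaky W&B table upload from killing validation, but the bare except: also swallows KeyboardInterrupt and SystemExit. A narrower variant might look like the sketch below; the helper name and the logging are illustrative, not from the repo.

import logging

log = logging.getLogger(__name__)


def log_table_safely(logger, columns, rows):
    # Exception excludes KeyboardInterrupt/SystemExit, and the failure is recorded
    try:
        logger.log_table(key="val/predictions", columns=columns, data=rows)
    except Exception:
        log.exception("could not log val/predictions table")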
@@ -229,7 +276,7 @@ class UNet(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.RMSprop(
             self.parameters(),
-            lr=wandb.config.LEARNING_RATE,
+            lr=self.learning_rate,
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
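This last hunk is what would make the LR finder effective: Lightning writes the suggested rate back to model.learning_rate, so configure_optimizers must read that attribute rather than the frozen wandb.config value. A stripped-down illustration (stand-in class, illustrative hyperparameter values):

import torch


class Demo:  # stand-in for the LightningModule
    def __init__(self):
        self.learning_rate = 1e-4
        self.weight = torch.nn.Parameter(torch.zeros(1))

    def configure_optimizers(self):
        # reads the attribute at call time, so a tuned value takes effect
        return torch.optim.RMSprop([self.weight], lr=self.learning_rate,
                                   weight_decay=1e-8, momentum=0.9)


m = Demo()
m.learning_rate = 3e-3  # what trainer.tune() effectively does after lr_find
print(m.configure_optimizers().param_groups[0]["lr"])  # 0.003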