refactor: making the code actually work
Former-commit-id: 302f77ef109ca44eef8c4ce6c5b3f59ba2891884 [formerly 5ae3d84ea691440283f73b5f12e2c85b7d95b191]
Former-commit-id: 33cbfef93d9be0cefd526e1cdbd7c0435db4c613
commit 8d903e7ad1
parent 81cbfd6212
@@ -1 +1 @@
-from .dataloader import SyntheticSphere
+from .dataloader import Spheres
@@ -1,6 +1,6 @@
 import albumentations as A
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Subset

 import wandb
 from utils import RandomPaste
@@ -8,7 +8,7 @@ from utils import RandomPaste
 from .dataset import LabeledDataset, SyntheticDataset


-class SyntheticSphere(pl.LightningDataModule):
+class Spheres(pl.LightningDataModule):
     def __init__(self):
         super().__init__()

@@ -25,7 +25,7 @@ class SyntheticSphere(pl.LightningDataModule):
         )

         dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
-        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))

         return DataLoader(
             dataset,
@@ -37,6 +37,7 @@ class SyntheticSphere(pl.LightningDataModule):

     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
+        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

         return DataLoader(
             dataset,
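The two Subset lines above implement stride-based subsampling: keep every k-th sample so a train epoch sees roughly 10000 images and validation roughly 100. A minimal sketch of the pattern (ToyDataset is a hypothetical stand-in, not part of this repo); the + 1 is the actual fix, since the old commented-out // 10000 alone yields step 0 on small datasets and makes range() raise ValueError:

from torch.utils.data import Dataset, Subset

class ToyDataset(Dataset):
    """Stand-in dataset; only __len__/__getitem__ matter here."""
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, i):
        return i

ds = ToyDataset(250)
# Stride chosen so at most ~100 samples survive; "+ 1" keeps the
# step >= 1 even when len(ds) < 100.
step = len(ds) // 100 + 1
subset = Subset(ds, list(range(0, len(ds), step)))
print(len(subset))  # 84 of 250 samples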
@@ -63,7 +63,7 @@ class LabeledDataset(Dataset):
         # convert image & mask to Tensor float in [0, 1]
         post_process = A.Compose(
             [
-                # A.SmallestMaxSize(1024),
+                A.SmallestMaxSize(1024),
                 A.ToFloat(max_value=255),
                 ToTensorV2(),
             ],
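Un-commenting A.SmallestMaxSize(1024) makes labeled images rescale so their smaller side is 1024 px before tensor conversion. A self-contained sketch of the chain, with random arrays standing in for real data:

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

# Resize so the smaller side is 1024 px, scale uint8 [0, 255] to
# float [0, 1], then convert HWC numpy arrays to CHW torch tensors.
post_process = A.Compose(
    [
        A.SmallestMaxSize(1024),
        A.ToFloat(max_value=255),
        ToTensorV2(),
    ],
)

image = np.random.randint(0, 256, (600, 800, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (600, 800), dtype=np.uint8)
out = post_process(image=image, mask=mask)
print(out["image"].shape)  # torch.Size([3, 1024, 1365]): 600 -> 1024, 800 scaled to match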
src/train.py (27 lines changed)

@@ -1,11 +1,12 @@
 import logging

 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import RichProgressBar
+from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger

 import wandb
-from unet import UNet
+from data import Spheres
+from unet import UNetModule

 CONFIG = {
     "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
@@ -18,9 +19,9 @@ CONFIG = {
     "PIN_MEMORY": True,
     "BENCHMARK": True,
     "DEVICE": "gpu",
-    "WORKERS": 10,
+    "WORKERS": 14,
     "EPOCHS": 1,
-    "BATCH_SIZE": 32,
+    "BATCH_SIZE": 16 * 3,
     "LEARNING_RATE": 1e-4,
     "WEIGHT_DECAY": 1e-8,
     "MOMENTUM": 0.9,
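CONFIG doubles as the wandb run config: the datamodule above reads paths such as wandb.config.DIR_TRAIN_IMG. Presumably the wiring looks like the sketch below (the init call itself is outside this hunk, and the project name is illustrative):

import wandb

CONFIG = {
    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
    "WORKERS": 14,
    "BATCH_SIZE": 16 * 3,  # e.g. 16 per device across 3 devices (assumption)
}

# wandb.init publishes the dict as wandb.config, which other modules
# (here the dataloaders) read back as attributes.
wandb.init(project="demo-project", config=CONFIG, mode="offline")
assert wandb.config.BATCH_SIZE == 48
assert wandb.config.DIR_TRAIN_IMG.endswith("/train/")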
@@ -45,7 +46,7 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)

     # Create network
-    net = UNet(
+    model = UNetModule(
         n_channels=CONFIG["N_CHANNELS"],
         n_classes=CONFIG["N_CLASSES"],
         batch_size=CONFIG["BATCH_SIZE"],
@@ -54,27 +55,31 @@ if __name__ == "__main__":
     )

     # log gradients and weights regularly
-    logger.watch(net, log="all")
+    logger.watch(model, log="all")

     # create checkpoint callback
-    checkpoint_callback = pl.ModelCheckpoint(
+    checkpoint_callback = ModelCheckpoint(
         dirpath="checkpoints",
         monitor="val/dice",
     )

+    # Create the dataloaders
+    datamodule = Spheres()
+
     # Create the trainer
     trainer = pl.Trainer(
         max_epochs=CONFIG["EPOCHS"],
         accelerator=CONFIG["DEVICE"],
-        # precision=16,
         benchmark=CONFIG["BENCHMARK"],
-        val_check_interval=100,
-        callbacks=RichProgressBar(),
+        # profiler="simple",
+        # precision=16,
         logger=logger,
         log_every_n_steps=1,
+        val_check_interval=100,
+        callbacks=RichProgressBar(),
     )

-    trainer.fit(model=net)
+    trainer.fit(model=model, datamodule=datamodule)

     # stop wandb
     wandb.run.finish()
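The key fix in this file: training data now comes from a LightningDataModule passed explicitly to trainer.fit(), instead of the model being expected to provide its own loaders. A minimal runnable sketch of that wiring (the Demo classes are illustrative, not the project's):

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

class DemoData(pl.LightningDataModule):
    # The datamodule owns the loaders.
    def train_dataloader(self):
        x = torch.randn(64, 1)
        return DataLoader(TensorDataset(x, 2 * x), batch_size=16)

class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(1, 1)
    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(model=DemoModel(), datamodule=DemoData())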
@@ -29,8 +29,6 @@ class UNet(nn.Module):
     def forward(self, x):
         skips = []

-        x = x.to(self.device)
-
         x = self.inc(x)

         for down in self.downs:
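This hunk drops the manual device transfer from UNet.forward. Under Lightning, the Trainer moves each batch to the module's device before the step hooks run, so the extra x.to(self.device) was redundant there; a minimal sketch of that assumption (Demo is illustrative):

import pytorch_lightning as pl
import torch

class Demo(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        # No x.to(self.device) here: when called from the step hooks,
        # the batch is already on the right device.
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch  # already on self.device
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

One caveat: code invoking forward directly, outside the Trainer, must now place inputs on the device itself.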
@@ -73,7 +73,7 @@ class UNetModule(pl.LightningModule):
         columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
         for i, (img, mask, pred, pred_bin) in enumerate(
             zip( # TODO: use comprehension list to zip the dictionnary
-                tensors["images"].cpu(),
+                tensors["data"].cpu(),
                 tensors["ground_truth"].cpu(),
                 tensors["prediction"].cpu(),
                 tensors["binary"]
@@ -121,7 +121,7 @@ class UNetModule(pl.LightningModule):
         if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
             for i, (img, mask, pred, pred_bin) in enumerate(
                 zip( # TODO: use comprehension list to zip the dictionnary
-                    tensors["images"].cpu(),
+                    tensors["data"].cpu(),
                     tensors["ground_truth"].cpu(),
                     tensors["prediction"].cpu(),
                     tensors["binary"]
@@ -150,11 +150,12 @@ class UNetModule(pl.LightningModule):
             ]
         )

-        return metrics
+        return metrics, rows

     def validation_epoch_end(self, validation_outputs):
         # unpacking
-        metricss, rowss = validation_outputs
+        metricss = [v[0] for v in validation_outputs]
+        rowss = [v[1] for v in validation_outputs]

         # metrics flattening
         metrics = {
@@ -162,7 +163,7 @@ class UNetModule(pl.LightningModule):
             "dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
             "bce": torch.stack([d["bce"] for d in metricss]).mean(),
             "mae": torch.stack([d["mae"] for d in metricss]).mean(),
-            "accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
+            "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
         }

         # log metrics
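The unpacking fix addresses a real bug: validation_epoch_end receives a list with one (metrics, rows) tuple per validation batch, so the old metricss, rowss = validation_outputs only worked when there happened to be exactly two batches, and the old accuracy line stacked over the tuples themselves rather than the metric dicts. A small illustration with made-up values:

import torch

# One (metrics, rows) tuple per validation batch, as returned by the
# updated validation_step (values here are invented).
validation_outputs = [
    ({"dice": torch.tensor(0.8), "accuracy": torch.tensor(0.9)}, [["row", 0]]),
    ({"dice": torch.tensor(0.6), "accuracy": torch.tensor(0.7)}, [["row", 1]]),
    ({"dice": torch.tensor(0.7), "accuracy": torch.tensor(0.8)}, [["row", 2]]),
]

metricss = [v[0] for v in validation_outputs]  # per-batch metric dicts
rowss = [v[1] for v in validation_outputs]     # per-batch table rows

metrics = {
    "dice": torch.stack([d["dice"] for d in metricss]).mean(),
    # the old code indexed d["accuracy"] on tuples from
    # validation_outputs, which raises TypeError; metricss is correct
    "accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
}
print(metrics["dice"].item())  # ~0.7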