refactor: making the code actually work
Former-commit-id: 302f77ef109ca44eef8c4ce6c5b3f59ba2891884 [formerly 5ae3d84ea691440283f73b5f12e2c85b7d95b191] Former-commit-id: 33cbfef93d9be0cefd526e1cdbd7c0435db4c613
This commit is contained in:
parent
81cbfd6212
commit
8d903e7ad1
|
@ -1 +1 @@
|
|||
from .dataloader import SyntheticSphere
|
||||
from .dataloader import Spheres
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import albumentations as A
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
import wandb
|
||||
from utils import RandomPaste
|
||||
|
@ -8,7 +8,7 @@ from utils import RandomPaste
|
|||
from .dataset import LabeledDataset, SyntheticDataset
|
||||
|
||||
|
||||
class SyntheticSphere(pl.LightningDataModule):
|
||||
class Spheres(pl.LightningDataModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -25,7 +25,7 @@ class SyntheticSphere(pl.LightningDataModule):
|
|||
)
|
||||
|
||||
dataset = SyntheticDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=transform)
|
||||
# ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 10000 + 1)))
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
|
@ -37,6 +37,7 @@ class SyntheticSphere(pl.LightningDataModule):
|
|||
|
||||
def val_dataloader(self):
|
||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
|
|
|
@ -63,7 +63,7 @@ class LabeledDataset(Dataset):
|
|||
# convert image & mask to Tensor float in [0, 1]
|
||||
post_process = A.Compose(
|
||||
[
|
||||
# A.SmallestMaxSize(1024),
|
||||
A.SmallestMaxSize(1024),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
|
|
27
src/train.py
27
src/train.py
|
@ -1,11 +1,12 @@
|
|||
import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks import RichProgressBar
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
import wandb
|
||||
from unet import UNet
|
||||
from data import Spheres
|
||||
from unet import UNetModule
|
||||
|
||||
CONFIG = {
|
||||
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
|
||||
|
@ -18,9 +19,9 @@ CONFIG = {
|
|||
"PIN_MEMORY": True,
|
||||
"BENCHMARK": True,
|
||||
"DEVICE": "gpu",
|
||||
"WORKERS": 10,
|
||||
"WORKERS": 14,
|
||||
"EPOCHS": 1,
|
||||
"BATCH_SIZE": 32,
|
||||
"BATCH_SIZE": 16 * 3,
|
||||
"LEARNING_RATE": 1e-4,
|
||||
"WEIGHT_DECAY": 1e-8,
|
||||
"MOMENTUM": 0.9,
|
||||
|
@ -45,7 +46,7 @@ if __name__ == "__main__":
|
|||
pl.seed_everything(69420, workers=True)
|
||||
|
||||
# Create network
|
||||
net = UNet(
|
||||
model = UNetModule(
|
||||
n_channels=CONFIG["N_CHANNELS"],
|
||||
n_classes=CONFIG["N_CLASSES"],
|
||||
batch_size=CONFIG["BATCH_SIZE"],
|
||||
|
@ -54,27 +55,31 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
# log gradients and weights regularly
|
||||
logger.watch(net, log="all")
|
||||
logger.watch(model, log="all")
|
||||
|
||||
# create checkpoint callback
|
||||
checkpoint_callback = pl.ModelCheckpoint(
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath="checkpoints",
|
||||
monitor="val/dice",
|
||||
)
|
||||
|
||||
# Create the dataloaders
|
||||
datamodule = Spheres()
|
||||
|
||||
# Create the trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=CONFIG["EPOCHS"],
|
||||
accelerator=CONFIG["DEVICE"],
|
||||
# precision=16,
|
||||
benchmark=CONFIG["BENCHMARK"],
|
||||
val_check_interval=100,
|
||||
callbacks=RichProgressBar(),
|
||||
# profiler="simple",
|
||||
# precision=16,
|
||||
logger=logger,
|
||||
log_every_n_steps=1,
|
||||
val_check_interval=100,
|
||||
callbacks=RichProgressBar(),
|
||||
)
|
||||
|
||||
trainer.fit(model=net)
|
||||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
||||
# stop wandb
|
||||
wandb.run.finish()
|
||||
|
|
|
@ -29,8 +29,6 @@ class UNet(nn.Module):
|
|||
def forward(self, x):
|
||||
skips = []
|
||||
|
||||
x = x.to(self.device)
|
||||
|
||||
x = self.inc(x)
|
||||
|
||||
for down in self.downs:
|
||||
|
|
|
@ -73,7 +73,7 @@ class UNetModule(pl.LightningModule):
|
|||
columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
zip( # TODO: use a list comprehension to zip the dictionary
|
||||
tensors["images"].cpu(),
|
||||
tensors["data"].cpu(),
|
||||
tensors["ground_truth"].cpu(),
|
||||
tensors["prediction"].cpu(),
|
||||
tensors["binary"]
|
||||
|
@ -121,7 +121,7 @@ class UNetModule(pl.LightningModule):
|
|||
if batch_idx % 50 == 0 or metrics["dice"] > 0.9:
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
zip( # TODO: use a list comprehension to zip the dictionary
|
||||
tensors["images"].cpu(),
|
||||
tensors["data"].cpu(),
|
||||
tensors["ground_truth"].cpu(),
|
||||
tensors["prediction"].cpu(),
|
||||
tensors["binary"]
|
||||
|
@ -150,11 +150,12 @@ class UNetModule(pl.LightningModule):
|
|||
]
|
||||
)
|
||||
|
||||
return metrics
|
||||
return metrics, rows
|
||||
|
||||
def validation_epoch_end(self, validation_outputs):
|
||||
# unpacking
|
||||
metricss, rowss = validation_outputs
|
||||
metricss = [v[0] for v in validation_outputs]
|
||||
rowss = [v[1] for v in validation_outputs]
|
||||
|
||||
# metrics flattening
|
||||
metrics = {
|
||||
|
@ -162,7 +163,7 @@ class UNetModule(pl.LightningModule):
|
|||
"dice_bin": torch.stack([d["dice_bin"] for d in metricss]).mean(),
|
||||
"bce": torch.stack([d["bce"] for d in metricss]).mean(),
|
||||
"mae": torch.stack([d["mae"] for d in metricss]).mean(),
|
||||
"accuracy": torch.stack([d["accuracy"] for d in validation_outputs]).mean(),
|
||||
"accuracy": torch.stack([d["accuracy"] for d in metricss]).mean(),
|
||||
}
|
||||
|
||||
# log metrics
|
||||
|
|
Loading…
Reference in a new issue