feat: working logging, auto_batch/lr still not working

Former-commit-id: 29d4536eb182f84eb2cc9a4e31f31bf19a4ca272 [formerly f5fd5eec9394b81f15986fb6cbabf675b2f05c04]
Former-commit-id: 3de00ee718a761221c1934b7cbaaa0ad5487856d
Laurent Fainsin 2022-07-05 22:31:38 +02:00
parent e4562e2481
commit 40ea1c3191
2 changed files with 84 additions and 55 deletions

View file

@@ -19,7 +19,7 @@ CONFIG = {
     "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
     "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
     "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
-    "FEATURES": [64, 128, 256, 512],
+    "FEATURES": [16, 32, 64, 128],
     "N_CHANNELS": 3,
     "N_CLASSES": 1,
     "AMP": True,
@@ -53,7 +53,13 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)

     # 0. Create network
-    net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"])
+    net = UNet(
+        n_channels=CONFIG["N_CHANNELS"],
+        n_classes=CONFIG["N_CLASSES"],
+        batch_size=CONFIG["BATCH_SIZE"],
+        learning_rate=CONFIG["LEARNING_RATE"],
+        features=CONFIG["FEATURES"],
+    )

     # log gradients and weights regularly
     logger.watch(net, log="all")
@@ -77,7 +83,7 @@ if __name__ == "__main__":
     ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])

     # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
     # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
     # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
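Side note, not part of the commit: the Subset call above now keeps roughly 5000 evenly spaced training samples instead of ~10000. A minimal illustration of the index arithmetic, with a made-up dataset size N:

# Hypothetical illustration only; N = 100_000 is not a value from this repo.
N = 100_000
indices = list(range(0, N, N // 5000))  # stride 20 -> indices 0, 20, 40, ...
assert len(indices) == 5000  # about 5000 evenly spaced samples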
@@ -104,9 +110,12 @@ if __name__ == "__main__":
         accelerator=CONFIG["DEVICE"],
         # precision=16,
         auto_scale_batch_size="binsearch",
+        auto_lr_find=True,
         benchmark=CONFIG["BENCHMARK"],
         val_check_interval=100,
         callbacks=RichProgressBar(),
+        logger=logger,
+        log_every_n_steps=1,
     )

     try:
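As the commit subject notes, auto_scale_batch_size and auto_lr_find still do nothing here. In PyTorch Lightning 1.x these Trainer flags only run when trainer.tune() is called explicitly before trainer.fit(); a hedged sketch of the missing step (the dataloader variable names are placeholders, not variables from this diff):

# Sketch only, assuming PL 1.x semantics: tune() performs the batch-size
# binsearch and the LR finder, then writes the results back onto the model
# (net.batch_size, net.learning_rate), which is why the commit now threads
# both values through UNet.__init__.
trainer.tune(net, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.fit(net, train_dataloaders=train_loader, val_dataloaders=val_loader)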

View file

@@ -1,6 +1,5 @@
 """ Full assembly of the parts to form the complete network """

-import numpy as np
 import pytorch_lightning as pl
 import wandb
@@ -14,11 +13,16 @@ class_labels = {

 class UNet(pl.LightningModule):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
+    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
         super(UNet, self).__init__()

+        # Hyperparameters
         self.n_channels = n_channels
         self.n_classes = n_classes
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+
+        # Network
         self.inc = DoubleConv(n_channels, features[0])

         self.downs = nn.ModuleList()
@@ -39,6 +43,7 @@ class UNet(pl.LightningModule):
         skips = []

         x = x.to(self.device)
         x = self.inc(x)
+
         for down in self.downs:
@@ -78,77 +83,97 @@
             ),
         )

-        wandb.log(
-            {
-                log_key: table,
-            }
-        )
+        wandb.log({log_key: table})  # replace by self.log

     def training_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
         masks_true = masks_true.unsqueeze(1)
-        masks_pred = self(images)
-        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-
-        # compute metrics
-        loss = F.cross_entropy(masks_pred, masks_true)
+
+        # forward pass
+        masks_pred = self(images)
+
+        # compute loss
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+
+        # compute other metrics
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
         mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
         accuracy = (masks_true == masks_pred_bin).float().mean()
         dice = dice_coeff(masks_pred_bin, masks_true)

-        self.log(
-            "train",
+        self.log_dict(
             {
-                "accuracy": accuracy,
-                "bce": loss,
-                "dice": dice,
-                "mae": mae,
+                "train/accuracy": accuracy,
+                "train/bce": bce,
+                "train/dice": dice,
+                "train/mae": mae,
             },
         )

-        return loss  # , dice, accuracy, mae
+        return dict(
+            loss=bce,
+            dice=dice,
+            accuracy=accuracy,
+            mae=mae,
+        )

     def validation_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
         masks_true = masks_true.unsqueeze(1)
-        masks_pred = self(images)
-        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-
-        # compute metrics
-        loss = F.cross_entropy(masks_pred, masks_true)
-        # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-        # accuracy = (masks_true == masks_pred_bin).float().mean()
-        # dice = dice_coeff(masks_pred_bin, masks_true)
+
+        # forward pass
+        masks_pred = self(images)
+
+        # compute loss
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+
+        # compute other metrics
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+        dice = dice_coeff(masks_pred_bin, masks_true)

         if batch_idx == 0:
             self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")

-        return loss  # , dice, accuracy, mae
+        return dict(
+            loss=bce,
+            dice=dice,
+            accuracy=accuracy,
+            mae=mae,
+        )

-    # def validation_step_end(self, validation_outputs):
-    #     # unpacking
-    #     loss, dice, accuracy, mae = validation_outputs
-    #     # optimizer = self.optimizers[0]
-    #     # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
-
-    #     wandb.log(
-    #         {
-    #             # "train/learning_rate": learning_rate,
-    #             "val/accuracy": accuracy,
-    #             "val/bce": loss,
-    #             "val/dice": dice,
-    #             "val/mae": mae,
-    #         }
-    #     )
-
-    #     # export model to onnx
-    #     dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
-    #     torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
-    #     artifact = wandb.Artifact("onnx", type="model")
-    #     artifact.add_file(f"checkpoints/model.onnx")
-    #     wandb.run.log_artifact(artifact)
+    def validation_epoch_end(self, validation_outputs):
+        # unpacking
+        accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
+        loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
+        dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
+        mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
+
+        # logging
+        wandb.log(
+            {
+                "val/accuracy": accuracy,
+                "val/bce": loss,
+                "val/dice": dice,
+                "val/mae": mae,
+            }
+        )
+
+        # export model to pth
+        torch.save(self.state_dict(), f"checkpoints/model.pth")
+        artifact = wandb.Artifact("pth", type="model")
+        artifact.add_file(f"checkpoints/model.pth")
+        wandb.run.log_artifact(artifact)
+
+        # export model to onnx
+        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
+        torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
+        artifact = wandb.Artifact("onnx", type="model")
+        artifact.add_file(f"checkpoints/model.onnx")
+        wandb.run.log_artifact(artifact)

     # def test_step(self, batch, batch_idx):
     #     # unpacking
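dice_coeff is imported elsewhere in the repo and is not shown in this diff; for reference, a minimal sketch of the soft Dice coefficient it is assumed to compute on the binarised masks:

import torch

# Assumed behaviour only; the real helper lives outside this diff.
def dice_coeff(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Dice = 2*|A ∩ B| / (|A| + |B|), averaged over a (N, C, H, W) batch.
    inter = (pred * target).sum(dim=(1, 2, 3))
    total = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    return ((2 * inter + eps) / (total + eps)).mean()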
@@ -199,10 +224,5 @@ class UNet(pl.LightningModule):
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
-        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        #     optimizer,
-        #     "max",
-        #     patience=2,
-        # )

-        return optimizer  # , scheduler
+        return optimizer
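One likely reason auto_lr_find still has no visible effect: the value found by the tuner is written to self.learning_rate, so configure_optimizers has to read the rate from that attribute rather than from wandb.config. A hedged sketch (the SGD choice is an assumption; the hunk above only shows the momentum and weight_decay arguments):

# Hypothetical sketch, not this repo's exact code.
def configure_optimizers(self):
    optimizer = torch.optim.SGD(
        self.parameters(),
        lr=self.learning_rate,  # picked up from trainer.tune() / CONFIG["LEARNING_RATE"]
        weight_decay=wandb.config.WEIGHT_DECAY,
        momentum=wandb.config.MOMENTUM,
    )
    return optimizer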