Mirror of https://github.com/Laurent2916/REVA-QCAV.git, synced 2024-11-09 23:12:05 +00:00
feat: working logging, auto_batch/lr still not working
Former-commit-id: 29d4536eb182f84eb2cc9a4e31f31bf19a4ca272 [formerly f5fd5eec9394b81f15986fb6cbabf675b2f05c04]
Former-commit-id: 3de00ee718a761221c1934b7cbaaa0ad5487856d
Parent: e4562e2481, commit: 40ea1c3191

Changed files: src/train.py (15 lines), plus the file containing the UNet LightningModule (hunks below; its path is not shown in this capture).
src/train.py:

@@ -19,7 +19,7 @@ CONFIG = {
     "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
     "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
     "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
-    "FEATURES": [64, 128, 256, 512],
+    "FEATURES": [16, 32, 64, 128],
     "N_CHANNELS": 3,
     "N_CLASSES": 1,
     "AMP": True,
@@ -53,7 +53,13 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)
 
     # 0. Create network
-    net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"])
+    net = UNet(
+        n_channels=CONFIG["N_CHANNELS"],
+        n_classes=CONFIG["N_CLASSES"],
+        batch_size=CONFIG["BATCH_SIZE"],
+        learning_rate=CONFIG["LEARNING_RATE"],
+        features=CONFIG["FEATURES"],
+    )
 
     # log gradients and weights regularly
     logger.watch(net, log="all")
@@ -77,7 +83,7 @@ if __name__ == "__main__":
     ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
 
     # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
     # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
     # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
 
@@ -104,9 +110,12 @@ if __name__ == "__main__":
         accelerator=CONFIG["DEVICE"],
         # precision=16,
         auto_scale_batch_size="binsearch",
+        auto_lr_find=True,
         benchmark=CONFIG["BENCHMARK"],
         val_check_interval=100,
         callbacks=RichProgressBar(),
+        logger=logger,
+        log_every_n_steps=1,
     )
 
     try:
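Note on the tuning flags added above: the commit message says auto_batch/lr is still not working, and in PyTorch Lightning 1.x the auto_scale_batch_size and auto_lr_find Trainer flags do nothing during a plain fit(); the finders only run from trainer.tune(), and they work by overwriting the batch_size and learning_rate attributes that the UNet constructor now stores. A minimal sketch of the expected call order, assuming the net, logger and CONFIG objects from this diff (the tune() call itself is not visible in the captured hunks):

# sketch only, not part of this commit
trainer = pl.Trainer(
    auto_scale_batch_size="binsearch",  # binary-searches a working net.batch_size
    auto_lr_find=True,                  # LR range test, then sets net.learning_rate
    logger=logger,
)
trainer.tune(net)  # without this call, fit() skips both finders
trainer.fit(net)   # dataloader arguments omitted in this sketch

The tuned values only take effect if the training dataloader is built from self.batch_size and configure_optimizers reads self.learning_rate; if either still comes from CONFIG or wandb.config, the finder results are silently ignored.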
UNet LightningModule (file path not shown in this capture):

@@ -1,6 +1,5 @@
 """ Full assembly of the parts to form the complete network """
 
-import numpy as np
 import pytorch_lightning as pl
 
 import wandb
@@ -14,11 +13,16 @@ class_labels = {
 
 
 class UNet(pl.LightningModule):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
+    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
         super(UNet, self).__init__()
 
+        # Hyperparameters
         self.n_channels = n_channels
         self.n_classes = n_classes
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+
+        # Network
         self.inc = DoubleConv(n_channels, features[0])
 
         self.downs = nn.ModuleList()
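The two new constructor arguments exist so Lightning's tuner has something to mutate: auto_scale_batch_size looks for a batch_size attribute and auto_lr_find for a learning_rate attribute on the LightningModule (or in its hparams). A hedged variant, not what this commit does, that additionally registers the arguments with save_hyperparameters() so they end up in checkpoints and in the logger's config:

import pytorch_lightning as pl

class UNet(pl.LightningModule):
    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
        super().__init__()
        self.save_hyperparameters()         # records all init args in self.hparams
        self.learning_rate = learning_rate  # plain attribute read and overwritten by auto_lr_find
        self.batch_size = batch_size        # plain attribute read and overwritten by auto_scale_batch_size
        self.n_channels = n_channels
        self.n_classes = n_classes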
@@ -39,6 +43,7 @@ class UNet(pl.LightningModule):
         skips = []
 
         x = x.to(self.device)
+
         x = self.inc(x)
 
         for down in self.downs:
@@ -78,77 +83,97 @@ class UNet(pl.LightningModule):
             ),
         )
 
-        wandb.log(
-            {
-                log_key: table,
-            }
-        )
+        wandb.log({log_key: table})  # replace by self.log
 
     def training_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
         masks_true = masks_true.unsqueeze(1)
-        masks_pred = self(images)
-        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
 
-        # compute metrics
-        loss = F.cross_entropy(masks_pred, masks_true)
+        # forward pass
+        masks_pred = self(images)
+
+        # compute loss
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+
+        # compute other metrics
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
         mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
         accuracy = (masks_true == masks_pred_bin).float().mean()
         dice = dice_coeff(masks_pred_bin, masks_true)
 
-        self.log(
-            "train",
+        self.log_dict(
             {
-                "accuracy": accuracy,
-                "bce": loss,
-                "dice": dice,
-                "mae": mae,
+                "train/accuracy": accuracy,
+                "train/bce": bce,
+                "train/dice": dice,
+                "train/mae": mae,
             },
         )
 
-        return loss  # , dice, accuracy, mae
+        return dict(
+            loss=bce,
+            dice=dice,
+            accuracy=accuracy,
+            mae=mae,
+        )
 
     def validation_step(self, batch, batch_idx):
         # unpacking
         images, masks_true = batch
         masks_true = masks_true.unsqueeze(1)
-        masks_pred = self(images)
-        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
 
-        # compute metrics
-        loss = F.cross_entropy(masks_pred, masks_true)
-        # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-        # accuracy = (masks_true == masks_pred_bin).float().mean()
-        # dice = dice_coeff(masks_pred_bin, masks_true)
+        # forward pass
+        masks_pred = self(images)
+
+        # compute loss
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+
+        # compute other metrics
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+        dice = dice_coeff(masks_pred_bin, masks_true)
 
         if batch_idx == 0:
             self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
 
-        return loss  # , dice, accuracy, mae
+        return dict(
+            loss=bce,
+            dice=dice,
+            accuracy=accuracy,
+            mae=mae,
+        )
 
-    # def validation_step_end(self, validation_outputs):
-    #     # unpacking
-    #     loss, dice, accuracy, mae = validation_outputs
-    #     # optimizer = self.optimizers[0]
-    #     # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
+    def validation_epoch_end(self, validation_outputs):
+        # unpacking
+        accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
+        loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
+        dice = torch.stack([d["dice"] for d in validation_outputs]).mean()
+        mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
 
-    # wandb.log(
-    #     {
-    #         # "train/learning_rate": learning_rate,
-    #         "val/accuracy": accuracy,
-    #         "val/bce": loss,
-    #         "val/dice": dice,
-    #         "val/mae": mae,
-    #     }
-    # )
+        # logging
+        wandb.log(
+            {
+                "val/accuracy": accuracy,
+                "val/bce": loss,
+                "val/dice": dice,
+                "val/mae": mae,
+            }
+        )
 
-    # # export model to onnx
-    # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
-    # torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
-    # artifact = wandb.Artifact("onnx", type="model")
-    # artifact.add_file(f"checkpoints/model.onnx")
-    # wandb.run.log_artifact(artifact)
+        # export model to pth
+        torch.save(self.state_dict(), f"checkpoints/model.pth")
+        artifact = wandb.Artifact("pth", type="model")
+        artifact.add_file(f"checkpoints/model.pth")
+        wandb.run.log_artifact(artifact)
+
+        # export model to onnx
+        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
+        torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
+        artifact = wandb.Artifact("onnx", type="model")
+        artifact.add_file(f"checkpoints/model.onnx")
+        wandb.run.log_artifact(artifact)
 
     # def test_step(self, batch, batch_idx):
     #     # unpacking
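With N_CLASSES = 1 the network emits a single logit map per image, so the loss above moves from F.cross_entropy to F.binary_cross_entropy_with_logits (sigmoid folded into the loss, float target of the same shape, hence masks_true.unsqueeze(1)), and the remaining metrics are computed on the thresholded sigmoid output. A self-contained sketch of that metric block on dummy tensors; the dice_coeff below is a generic stand-in, not necessarily the repository's helper:

import torch
import torch.nn.functional as F

def dice_coeff(pred, target, eps=1e-6):
    # generic Dice on binarized masks, stand-in for the repo's dice_coeff
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)

masks_pred = torch.randn(4, 1, 64, 64)                    # raw logits, shape (B, 1, H, W)
masks_true = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary masks, same shape

bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae = F.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
print(f"bce={bce.item():.3f} mae={mae.item():.3f} acc={accuracy.item():.3f} dice={dice.item():.3f}")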
@@ -199,10 +224,5 @@ class UNet(pl.LightningModule):
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
-        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        #     optimizer,
-        #     "max",
-        #     patience=2,
-        # )
 
-        return optimizer  # , scheduler
+        return optimizer
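The commented-out ReduceLROnPlateau block is deleted rather than wired in. If it were reinstated, Lightning expects a metric-driven scheduler to be returned from configure_optimizers together with the optimizer. A sketch under assumptions: SGD stands in for whatever optimizer the file actually builds, and "val/dice" for the monitored key; neither is visible in this diff, and the monitored metric must be logged via self.log for Lightning to see it.

def configure_optimizers(self):
    optimizer = torch.optim.SGD(             # stand-in; only weight_decay/momentum are visible above
        self.parameters(),
        lr=self.learning_rate,               # assumed: read the tunable attribute
        weight_decay=wandb.config.WEIGHT_DECAY,
        momentum=wandb.config.MOMENTUM,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=2)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "val/dice"},
    }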