diff --git a/src/train.py b/src/train.py index 2471320..c54988a 100644 --- a/src/train.py +++ b/src/train.py @@ -19,7 +19,7 @@ CONFIG = { "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/", "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/", - "FEATURES": [64, 128, 256, 512], + "FEATURES": [16, 32, 64, 128], "N_CHANNELS": 3, "N_CLASSES": 1, "AMP": True, @@ -53,7 +53,13 @@ if __name__ == "__main__": pl.seed_everything(69420, workers=True) # 0. Create network - net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"]) + net = UNet( + n_channels=CONFIG["N_CHANNELS"], + n_classes=CONFIG["N_CLASSES"], + batch_size=CONFIG["BATCH_SIZE"], + learning_rate=CONFIG["LEARNING_RATE"], + features=CONFIG["FEATURES"], + ) # log gradients and weights regularly logger.watch(net, log="all") @@ -77,7 +83,7 @@ if __name__ == "__main__": ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"]) # 2.5. Create subset, if uncommented - ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) + ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) @@ -104,9 +110,12 @@ if __name__ == "__main__": accelerator=CONFIG["DEVICE"], # precision=16, auto_scale_batch_size="binsearch", + auto_lr_find=True, benchmark=CONFIG["BENCHMARK"], val_check_interval=100, callbacks=RichProgressBar(), + logger=logger, + log_every_n_steps=1, ) try: diff --git a/src/unet/model.py b/src/unet/model.py index b9d6c18..be5712e 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -1,6 +1,5 @@ """ Full assembly of the parts to form the complete network """ -import numpy as np import pytorch_lightning as pl import wandb @@ -14,11 +13,16 @@ class_labels = { class UNet(pl.LightningModule): - def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]): + def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]): super(UNet, self).__init__() + + # Hyperparameters self.n_channels = n_channels self.n_classes = n_classes + self.learning_rate = learning_rate + self.batch_size = batch_size + # Network self.inc = DoubleConv(n_channels, features[0]) self.downs = nn.ModuleList() @@ -39,6 +43,7 @@ class UNet(pl.LightningModule): skips = [] x = x.to(self.device) + x = self.inc(x) for down in self.downs: @@ -78,77 +83,97 @@ class UNet(pl.LightningModule): ), ) - wandb.log( - { - log_key: table, - } - ) + wandb.log({log_key: table}) # replace by self.log def training_step(self, batch, batch_idx): # unpacking images, masks_true = batch masks_true = masks_true.unsqueeze(1) - masks_pred = self(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - # compute metrics - loss = F.cross_entropy(masks_pred, masks_true) + # forward pass + masks_pred = self(images) + + # compute loss + bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) + + # compute other metrics + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) accuracy = (masks_true == masks_pred_bin).float().mean() dice = dice_coeff(masks_pred_bin, masks_true) - self.log( - "train", + self.log_dict( { - "accuracy": accuracy, - "bce": loss, - "dice": dice, - "mae": mae, + "train/accuracy": accuracy, + "train/bce": bce, + "train/dice": dice, + "train/mae": mae, }, ) - return loss # , dice, accuracy, mae + return dict( + loss=bce, + dice=dice, + accuracy=accuracy, + mae=mae, + ) def validation_step(self, batch, batch_idx): # unpacking images, masks_true = batch masks_true = masks_true.unsqueeze(1) - masks_pred = self(images) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - # compute metrics - loss = F.cross_entropy(masks_pred, masks_true) - # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) - # accuracy = (masks_true == masks_pred_bin).float().mean() - # dice = dice_coeff(masks_pred_bin, masks_true) + # forward pass + masks_pred = self(images) + + # compute loss + bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true) + + # compute other metrics + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) + accuracy = (masks_true == masks_pred_bin).float().mean() + dice = dice_coeff(masks_pred_bin, masks_true) if batch_idx == 0: self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions") - return loss # , dice, accuracy, mae + return dict( + loss=bce, + dice=dice, + accuracy=accuracy, + mae=mae, + ) - # def validation_step_end(self, validation_outputs): - # # unpacking - # loss, dice, accuracy, mae = validation_outputs - # # optimizer = self.optimizers[0] - # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] + def validation_epoch_end(self, validation_outputs): + # unpacking + accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean() + loss = torch.stack([d["loss"] for d in validation_outputs]).mean() + dice = torch.stack([d["dice"] for d in validation_outputs]).mean() + mae = torch.stack([d["mae"] for d in validation_outputs]).mean() - # wandb.log( - # { - # # "train/learning_rate": learning_rate, - # "val/accuracy": accuracy, - # "val/bce": loss, - # "val/dice": dice, - # "val/mae": mae, - # } - # ) + # logging + wandb.log( + { + "val/accuracy": accuracy, + "val/bce": loss, + "val/dice": dice, + "val/mae": mae, + } + ) - # # export model to onnx - # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) - # torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") - # artifact = wandb.Artifact("onnx", type="model") - # artifact.add_file(f"checkpoints/model.onnx") - # wandb.run.log_artifact(artifact) + # export model to pth + torch.save(self.state_dict(), f"checkpoints/model.pth") + artifact = wandb.Artifact("pth", type="model") + artifact.add_file(f"checkpoints/model.pth") + wandb.run.log_artifact(artifact) + + # export model to onnx + dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) + torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") + artifact = wandb.Artifact("onnx", type="model") + artifact.add_file(f"checkpoints/model.onnx") + wandb.run.log_artifact(artifact) # def test_step(self, batch, batch_idx): # # unpacking @@ -199,10 +224,5 @@ class UNet(pl.LightningModule): weight_decay=wandb.config.WEIGHT_DECAY, momentum=wandb.config.MOMENTUM, ) - # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - # optimizer, - # "max", - # patience=2, - # ) - return optimizer # , scheduler + return optimizer