feat: move run config into a module-level CONFIG dict and add rich progress bar (WIP, training still kinda broken)

Former-commit-id: 4cf02610721ba30c3dd1be6377daeeed907bc651 [formerly 52ef07ec8a123ddd362ac7c930eb6c915848e8b4]
Former-commit-id: 29fc18cae50625fd1f2868fc9696ca505f5648e2
This commit is contained in:
Laurent Fainsin 2022-07-05 15:18:31 +02:00
parent 982dfe99d7
commit e4562e2481
5 changed files with 153 additions and 151 deletions

1
.gitignore vendored
View file

@ -4,6 +4,7 @@ __pycache__/
wandb/
images/
lightning_logs/
checkpoints/
*.pth

39
poetry.lock generated
View file

@ -189,6 +189,17 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "commonmark"
version = "0.9.1"
description = "Python parser for the CommonMark Markdown spec"
category = "main"
optional = false
python-versions = "*"
[package.extras]
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
[[package]]
name = "cycler"
version = "0.11.0"
@ -881,7 +892,7 @@ python-versions = ">=3.6"
name = "pygments"
version = "2.12.0"
description = "Pygments is a syntax highlighting package written in Python."
category = "dev"
category = "main"
optional = false
python-versions = ">=3.6"
@ -1027,6 +1038,22 @@ requests = ">=2.0.0"
[package.extras]
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
[[package]]
name = "rich"
version = "12.4.4"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.6.3,<4.0.0"
[package.dependencies]
commonmark = ">=0.9.0,<0.10.0"
pygments = ">=2.6.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
[[package]]
name = "rsa"
version = "4.8"
@ -1446,7 +1473,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.11"
content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5"
content-hash = "416650c968a0021f7d64028f272464d96319c361a72888ae4cb3e2a602873832"
[metadata.files]
absl-py = [
@ -1652,6 +1679,10 @@ colorama = [
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
]
commonmark = [
{file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
{file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
]
cycler = [
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
@ -2419,6 +2450,10 @@ requests-oauthlib = [
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
]
rich = [
{file = "rich-12.4.4-py3-none-any.whl", hash = "sha256:d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec"},
{file = "rich-12.4.4.tar.gz", hash = "sha256:4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2"},
]
rsa = [
{file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"},
{file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"},

View file

@ -15,6 +15,7 @@ torch = "^1.11.0"
torchvision = "^0.12.0"
tqdm = "^4.64.0"
wandb = "^0.12.19"
rich = "^12.4.4"
[tool.poetry.dev-dependencies]
black = "^22.3.0"

View file

@ -3,8 +3,8 @@ import logging
import albumentations as A
import pytorch_lightning as pl
import torch
import yaml
from albumentations.pytorch import ToTensorV2
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
@ -13,8 +13,27 @@ from src.utils.dataset import SphereDataset
from unet import UNet
from utils.paste import RandomPaste
class_labels = {
1: "sphere",
# Run configuration, assembled from themed groups and merged below in a fixed
# key order. The resulting dict is handed to WandbLogger as the experiment
# config and is indexed directly (CONFIG["KEY"]) by the training script.

# Image directories: train/valid/test sets plus the sphere images and masks
# that are given to RandomPaste.
_PATHS = {
    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
    "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
    "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
    "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
}

# Arguments forwarded to UNet(n_channels=..., n_classes=..., features=...).
_MODEL = {
    "FEATURES": [64, 128, 256, 512],
    "N_CHANNELS": 3,
    "N_CLASSES": 1,
}

# Hardware / runtime switches (DataLoader workers and pin_memory, Trainer
# accelerator and benchmark).
# NOTE(review): no visible code reads "AMP" -- confirm it is still needed.
_RUNTIME = {
    "AMP": True,
    "PIN_MEMORY": True,
    "BENCHMARK": True,
    "DEVICE": "gpu",
    "WORKERS": 8,
}

# Training-loop and optimizer hyper-parameters.
_TRAINING = {
    "EPOCHS": 5,
    "BATCH_SIZE": 16,
    "LEARNING_RATE": 1e-4,
    "WEIGHT_DECAY": 1e-8,
    "MOMENTUM": 0.9,
}

# Augmentation settings: IMG_SIZE is used for both sides of A.Resize; SPHERES
# is the first argument of RandomPaste (presumably the number of spheres
# pasted per image -- confirm against utils.paste).
_AUGMENTATION = {
    "IMG_SIZE": 512,
    "SPHERES": 5,
}

# Merge order matters: it reproduces the key order of the flat configuration.
CONFIG = {**_PATHS, **_MODEL, **_RUNTIME, **_TRAINING, **_AUGMENTATION}
if __name__ == "__main__":
@ -24,28 +43,7 @@ if __name__ == "__main__":
# setup wandb
logger = WandbLogger(
project="U-Net",
config=dict(
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[64, 128, 256, 512],
N_CHANNELS=3,
N_CLASSES=1,
AMP=True,
PIN_MEMORY=True,
BENCHMARK=True,
DEVICE="gpu",
WORKERS=8,
EPOCHS=5,
BATCH_SIZE=16,
LEARNING_RATE=1e-4,
WEIGHT_DECAY=1e-8,
MOMENTUM=0.9,
IMG_SIZE=512,
SPHERES=5,
),
config=CONFIG,
settings=wandb.Settings(
code_dir="./src/",
),
@ -55,10 +53,7 @@ if __name__ == "__main__":
pl.seed_everything(69420, workers=True)
# 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
# log the number of parameters of the model
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"])
# log gradients and weights regularly
logger.watch(net, log="all")
@ -66,88 +61,59 @@ if __name__ == "__main__":
# 1. Create transforms
tf_train = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
A.Flip(),
A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
A.GaussianBlur(),
A.ISONoise(),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
tf_valid = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
# 2. Create datasets
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
# 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
# ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
# ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
# 3. Create data loaders
train_loader = DataLoader(
ds_train,
shuffle=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
batch_size=CONFIG["BATCH_SIZE"],
num_workers=CONFIG["WORKERS"],
pin_memory=CONFIG["PIN_MEMORY"],
)
val_loader = DataLoader(
ds_valid,
shuffle=False,
drop_last=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
test_loader = DataLoader(
ds_test,
shuffle=False,
drop_last=False,
batch_size=1,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
num_workers=CONFIG["WORKERS"],
pin_memory=CONFIG["PIN_MEMORY"],
)
# 4. Create the trainer
trainer = pl.Trainer(
max_epochs=wandb.config.EPOCHS,
accelerator="gpu",
precision=16,
max_epochs=CONFIG["EPOCHS"],
accelerator=CONFIG["DEVICE"],
# precision=16,
auto_scale_batch_size="binsearch",
benchmark=wandb.config.BENCHMARK,
benchmark=CONFIG["BENCHMARK"],
val_check_interval=100,
callbacks=RichProgressBar(),
)
# print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
# # wandb init log
# wandb.log(
# {
# "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
# },
# commit=False,
# )
try:
trainer.fit(
model=net,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
test_dataloaders=test_loader,
accelerator=wandb.config.DEVICE,
)
except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth")

View file

@ -1,7 +1,5 @@
""" Full assembly of the parts to form the complete network """
from xmlrpc.server import list_public_methods
import numpy as np
import pytorch_lightning as pl
@ -40,6 +38,7 @@ class UNet(pl.LightningModule):
def forward(self, x):
skips = []
x = x.to(self.device)
x = self.inc(x)
for down in self.downs:
@ -53,8 +52,7 @@ class UNet(pl.LightningModule):
return x
@staticmethod
def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key):
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred, pred_bin) in enumerate(
@ -99,16 +97,17 @@ class UNet(pl.LightningModule):
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
wandb.log(
self.log(
"train",
{
"train/accuracy": accuracy,
"train/bce": loss,
"train/dice": dice,
"train/mae": mae,
}
"accuracy": accuracy,
"bce": loss,
"dice": dice,
"mae": mae,
},
)
return loss, dice, accuracy, mae
return loss # , dice, accuracy, mae
def validation_step(self, batch, batch_idx):
# unpacking
@ -119,79 +118,79 @@ class UNet(pl.LightningModule):
# compute metrics
loss = F.cross_entropy(masks_pred, masks_true)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
# mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
# accuracy = (masks_true == masks_pred_bin).float().mean()
# dice = dice_coeff(masks_pred_bin, masks_true)
if batch_idx == 0:
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
return loss, dice, accuracy, mae
return loss # , dice, accuracy, mae
def validation_step_end(self, validation_outputs):
# unpacking
loss, dice, accuracy, mae = validation_outputs
optimizer = self.optimizers[0]
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
# def validation_step_end(self, validation_outputs):
# # unpacking
# loss, dice, accuracy, mae = validation_outputs
# # optimizer = self.optimizers[0]
# # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
wandb.log(
{
"train/learning_rate": learning_rate,
"val/accuracy": accuracy,
"val/bce": loss,
"val/dice": dice,
"val/mae": mae,
}
)
# wandb.log(
# {
# # "train/learning_rate": learning_rate,
# "val/accuracy": accuracy,
# "val/bce": loss,
# "val/dice": dice,
# "val/mae": mae,
# }
# )
# export model to onnx
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file(f"checkpoints/model.onnx")
wandb.run.log_artifact(artifact)
# # export model to onnx
# dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
# torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
# artifact = wandb.Artifact("onnx", type="model")
# artifact.add_file(f"checkpoints/model.onnx")
# wandb.run.log_artifact(artifact)
def test_step(self, batch, batch_idx):
# unpacking
images, masks_true = batch
masks_true = masks_true.unsqueeze(1)
masks_pred = self(images)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
# def test_step(self, batch, batch_idx):
# # unpacking
# images, masks_true = batch
# masks_true = masks_true.unsqueeze(1)
# masks_pred = self(images)
# masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
# compute metrics
loss = F.cross_entropy(masks_pred, masks_true)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true)
# # compute metrics
# loss = F.cross_entropy(masks_pred, masks_true)
# mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
# accuracy = (masks_true == masks_pred_bin).float().mean()
# dice = dice_coeff(masks_pred_bin, masks_true)
if batch_idx == 0:
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
# if batch_idx == 0:
# self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
return loss, dice, accuracy, mae
# return loss, dice, accuracy, mae
def test_step_end(self, test_outputs):
# unpacking
list_loss, list_dice, list_accuracy, list_mae = test_outputs
# def test_step_end(self, test_outputs):
# # unpacking
# list_loss, list_dice, list_accuracy, list_mae = test_outputs
# averaging
loss = np.mean(list_loss)
dice = np.mean(list_dice)
accuracy = np.mean(list_accuracy)
mae = np.mean(list_mae)
# # averaging
# loss = np.mean(list_loss)
# dice = np.mean(list_dice)
# accuracy = np.mean(list_accuracy)
# mae = np.mean(list_mae)
# get learning rate
optimizer = self.optimizers[0]
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
# # # get learning rate
# # optimizer = self.optimizers[0]
# # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
wandb.log(
{
"train/learning_rate": learning_rate,
"val/accuracy": accuracy,
"val/bce": loss,
"val/dice": dice,
"val/mae": mae,
}
)
# wandb.log(
# {
# # "train/learning_rate": learning_rate,
# "test/accuracy": accuracy,
# "test/bce": loss,
# "test/dice": dice,
# "test/mae": mae,
# }
# )
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
@ -200,10 +199,10 @@ class UNet(pl.LightningModule):
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
"max",
patience=2,
)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
# optimizer,
# "max",
# patience=2,
# )
return optimizer, scheduler
return optimizer # , scheduler