feat: kinda broken

Former-commit-id: 4cf02610721ba30c3dd1be6377daeeed907bc651 [formerly 52ef07ec8a123ddd362ac7c930eb6c915848e8b4]
Former-commit-id: 29fc18cae50625fd1f2868fc9696ca505f5648e2
This commit is contained in:
Laurent Fainsin 2022-07-05 15:18:31 +02:00
parent 982dfe99d7
commit e4562e2481
5 changed files with 153 additions and 151 deletions

1
.gitignore vendored
View file

@ -4,6 +4,7 @@ __pycache__/
wandb/ wandb/
images/ images/
lightning_logs/
checkpoints/ checkpoints/
*.pth *.pth

39
poetry.lock generated
View file

@ -189,6 +189,17 @@ category = "main"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "commonmark"
version = "0.9.1"
description = "Python parser for the CommonMark Markdown spec"
category = "main"
optional = false
python-versions = "*"
[package.extras]
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
[[package]] [[package]]
name = "cycler" name = "cycler"
version = "0.11.0" version = "0.11.0"
@ -881,7 +892,7 @@ python-versions = ">=3.6"
name = "pygments" name = "pygments"
version = "2.12.0" version = "2.12.0"
description = "Pygments is a syntax highlighting package written in Python." description = "Pygments is a syntax highlighting package written in Python."
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
@ -1027,6 +1038,22 @@ requests = ">=2.0.0"
[package.extras] [package.extras]
rsa = ["oauthlib[signedtoken] (>=3.0.0)"] rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
[[package]]
name = "rich"
version = "12.4.4"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
category = "main"
optional = false
python-versions = ">=3.6.3,<4.0.0"
[package.dependencies]
commonmark = ">=0.9.0,<0.10.0"
pygments = ">=2.6.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
[[package]] [[package]]
name = "rsa" name = "rsa"
version = "4.8" version = "4.8"
@ -1446,7 +1473,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = ">=3.8,<3.11" python-versions = ">=3.8,<3.11"
content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5" content-hash = "416650c968a0021f7d64028f272464d96319c361a72888ae4cb3e2a602873832"
[metadata.files] [metadata.files]
absl-py = [ absl-py = [
@ -1652,6 +1679,10 @@ colorama = [
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
] ]
commonmark = [
{file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
{file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
]
cycler = [ cycler = [
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"}, {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"}, {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
@ -2419,6 +2450,10 @@ requests-oauthlib = [
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
] ]
rich = [
{file = "rich-12.4.4-py3-none-any.whl", hash = "sha256:d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec"},
{file = "rich-12.4.4.tar.gz", hash = "sha256:4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2"},
]
rsa = [ rsa = [
{file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"}, {file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"},
{file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"}, {file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"},

View file

@ -15,6 +15,7 @@ torch = "^1.11.0"
torchvision = "^0.12.0" torchvision = "^0.12.0"
tqdm = "^4.64.0" tqdm = "^4.64.0"
wandb = "^0.12.19" wandb = "^0.12.19"
rich = "^12.4.4"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
black = "^22.3.0" black = "^22.3.0"

View file

@ -3,8 +3,8 @@ import logging
import albumentations as A import albumentations as A
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import yaml
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -13,8 +13,27 @@ from src.utils.dataset import SphereDataset
from unet import UNet from unet import UNet
from utils.paste import RandomPaste from utils.paste import RandomPaste
class_labels = { CONFIG = {
1: "sphere", "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
"DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
"DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
"FEATURES": [64, 128, 256, 512],
"N_CHANNELS": 3,
"N_CLASSES": 1,
"AMP": True,
"PIN_MEMORY": True,
"BENCHMARK": True,
"DEVICE": "gpu",
"WORKERS": 8,
"EPOCHS": 5,
"BATCH_SIZE": 16,
"LEARNING_RATE": 1e-4,
"WEIGHT_DECAY": 1e-8,
"MOMENTUM": 0.9,
"IMG_SIZE": 512,
"SPHERES": 5,
} }
if __name__ == "__main__": if __name__ == "__main__":
@ -24,28 +43,7 @@ if __name__ == "__main__":
# setup wandb # setup wandb
logger = WandbLogger( logger = WandbLogger(
project="U-Net", project="U-Net",
config=dict( config=CONFIG,
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[64, 128, 256, 512],
N_CHANNELS=3,
N_CLASSES=1,
AMP=True,
PIN_MEMORY=True,
BENCHMARK=True,
DEVICE="gpu",
WORKERS=8,
EPOCHS=5,
BATCH_SIZE=16,
LEARNING_RATE=1e-4,
WEIGHT_DECAY=1e-8,
MOMENTUM=0.9,
IMG_SIZE=512,
SPHERES=5,
),
settings=wandb.Settings( settings=wandb.Settings(
code_dir="./src/", code_dir="./src/",
), ),
@ -55,10 +53,7 @@ if __name__ == "__main__":
pl.seed_everything(69420, workers=True) pl.seed_everything(69420, workers=True)
# 0. Create network # 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"])
# log the number of parameters of the model
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
# log gradients and weights regularly # log gradients and weights regularly
logger.watch(net, log="all") logger.watch(net, log="all")
@ -66,88 +61,59 @@ if __name__ == "__main__":
# 1. Create transforms # 1. Create transforms
tf_train = A.Compose( tf_train = A.Compose(
[ [
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE), A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
A.Flip(), A.Flip(),
A.ColorJitter(), A.ColorJitter(),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
A.GaussianBlur(), A.GaussianBlur(),
A.ISONoise(), A.ISONoise(),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
ToTensorV2(), ToTensorV2(),
], ],
) )
tf_valid = A.Compose(
[
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.ToFloat(max_value=255),
ToTensorV2(),
],
)
# 2. Create datasets # 2. Create datasets
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
# 2.5. Create subset, if uncommented # 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000))) ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000))) # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100))) # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
# 3. Create data loaders # 3. Create data loaders
train_loader = DataLoader( train_loader = DataLoader(
ds_train, ds_train,
shuffle=True, shuffle=True,
batch_size=wandb.config.BATCH_SIZE, batch_size=CONFIG["BATCH_SIZE"],
num_workers=wandb.config.WORKERS, num_workers=CONFIG["WORKERS"],
pin_memory=wandb.config.PIN_MEMORY, pin_memory=CONFIG["PIN_MEMORY"],
) )
val_loader = DataLoader( val_loader = DataLoader(
ds_valid, ds_valid,
shuffle=False, shuffle=False,
drop_last=True, drop_last=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
test_loader = DataLoader(
ds_test,
shuffle=False,
drop_last=False,
batch_size=1, batch_size=1,
num_workers=wandb.config.WORKERS, num_workers=CONFIG["WORKERS"],
pin_memory=wandb.config.PIN_MEMORY, pin_memory=CONFIG["PIN_MEMORY"],
) )
# 4. Create the trainer # 4. Create the trainer
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=wandb.config.EPOCHS, max_epochs=CONFIG["EPOCHS"],
accelerator="gpu", accelerator=CONFIG["DEVICE"],
precision=16, # precision=16,
auto_scale_batch_size="binsearch", auto_scale_batch_size="binsearch",
benchmark=wandb.config.BENCHMARK, benchmark=CONFIG["BENCHMARK"],
val_check_interval=100, val_check_interval=100,
callbacks=RichProgressBar(),
) )
# print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
# # wandb init log
# wandb.log(
# {
# "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
# },
# commit=False,
# )
try: try:
trainer.fit( trainer.fit(
model=net, model=net,
train_dataloaders=train_loader, train_dataloaders=train_loader,
val_dataloaders=val_loader, val_dataloaders=val_loader,
test_dataloaders=test_loader,
accelerator=wandb.config.DEVICE,
) )
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth") torch.save(net.state_dict(), "INTERRUPTED.pth")

View file

@ -1,7 +1,5 @@
""" Full assembly of the parts to form the complete network """ """ Full assembly of the parts to form the complete network """
from xmlrpc.server import list_public_methods
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
@ -40,6 +38,7 @@ class UNet(pl.LightningModule):
def forward(self, x): def forward(self, x):
skips = [] skips = []
x = x.to(self.device)
x = self.inc(x) x = self.inc(x)
for down in self.downs: for down in self.downs:
@ -53,8 +52,7 @@ class UNet(pl.LightningModule):
return x return x
@staticmethod def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key):
def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred, pred_bin) in enumerate( for i, (img, mask, pred, pred_bin) in enumerate(
@ -99,16 +97,17 @@ class UNet(pl.LightningModule):
accuracy = (masks_true == masks_pred_bin).float().mean() accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true) dice = dice_coeff(masks_pred_bin, masks_true)
wandb.log( self.log(
"train",
{ {
"train/accuracy": accuracy, "accuracy": accuracy,
"train/bce": loss, "bce": loss,
"train/dice": dice, "dice": dice,
"train/mae": mae, "mae": mae,
} },
) )
return loss, dice, accuracy, mae return loss # , dice, accuracy, mae
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# unpacking # unpacking
@ -119,79 +118,79 @@ class UNet(pl.LightningModule):
# compute metrics # compute metrics
loss = F.cross_entropy(masks_pred, masks_true) loss = F.cross_entropy(masks_pred, masks_true)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean() # accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true) # dice = dice_coeff(masks_pred_bin, masks_true)
if batch_idx == 0: if batch_idx == 0:
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions") self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
return loss, dice, accuracy, mae return loss # , dice, accuracy, mae
def validation_step_end(self, validation_outputs): # def validation_step_end(self, validation_outputs):
# unpacking # # unpacking
loss, dice, accuracy, mae = validation_outputs # loss, dice, accuracy, mae = validation_outputs
optimizer = self.optimizers[0] # # optimizer = self.optimizers[0]
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
wandb.log( # wandb.log(
{ # {
"train/learning_rate": learning_rate, # # "train/learning_rate": learning_rate,
"val/accuracy": accuracy, # "val/accuracy": accuracy,
"val/bce": loss, # "val/bce": loss,
"val/dice": dice, # "val/dice": dice,
"val/mae": mae, # "val/mae": mae,
} # }
) # )
# export model to onnx # # export model to onnx
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True) # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx") # torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
artifact = wandb.Artifact("onnx", type="model") # artifact = wandb.Artifact("onnx", type="model")
artifact.add_file(f"checkpoints/model.onnx") # artifact.add_file(f"checkpoints/model.onnx")
wandb.run.log_artifact(artifact) # wandb.run.log_artifact(artifact)
def test_step(self, batch, batch_idx): # def test_step(self, batch, batch_idx):
# unpacking # # unpacking
images, masks_true = batch # images, masks_true = batch
masks_true = masks_true.unsqueeze(1) # masks_true = masks_true.unsqueeze(1)
masks_pred = self(images) # masks_pred = self(images)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() # masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
# compute metrics # # compute metrics
loss = F.cross_entropy(masks_pred, masks_true) # loss = F.cross_entropy(masks_pred, masks_true)
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true) # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy = (masks_true == masks_pred_bin).float().mean() # accuracy = (masks_true == masks_pred_bin).float().mean()
dice = dice_coeff(masks_pred_bin, masks_true) # dice = dice_coeff(masks_pred_bin, masks_true)
if batch_idx == 0: # if batch_idx == 0:
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions") # self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
return loss, dice, accuracy, mae # return loss, dice, accuracy, mae
def test_step_end(self, test_outputs): # def test_step_end(self, test_outputs):
# unpacking # # unpacking
list_loss, list_dice, list_accuracy, list_mae = test_outputs # list_loss, list_dice, list_accuracy, list_mae = test_outputs
# averaging # # averaging
loss = np.mean(list_loss) # loss = np.mean(list_loss)
dice = np.mean(list_dice) # dice = np.mean(list_dice)
accuracy = np.mean(list_accuracy) # accuracy = np.mean(list_accuracy)
mae = np.mean(list_mae) # mae = np.mean(list_mae)
# get learning rate # # # get learning rate
optimizer = self.optimizers[0] # # optimizer = self.optimizers[0]
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"] # # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
wandb.log( # wandb.log(
{ # {
"train/learning_rate": learning_rate, # # "train/learning_rate": learning_rate,
"val/accuracy": accuracy, # "test/accuracy": accuracy,
"val/bce": loss, # "test/bce": loss,
"val/dice": dice, # "test/dice": dice,
"val/mae": mae, # "test/mae": mae,
} # }
) # )
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.RMSprop( optimizer = torch.optim.RMSprop(
@ -200,10 +199,10 @@ class UNet(pl.LightningModule):
weight_decay=wandb.config.WEIGHT_DECAY, weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM, momentum=wandb.config.MOMENTUM,
) )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, # optimizer,
"max", # "max",
patience=2, # patience=2,
) # )
return optimizer, scheduler return optimizer # , scheduler