Mirror of https://github.com/Laurent2916/REVA-QCAV.git, synced 2024-11-09 23:12:05 +00:00

Commit e4562e2481 (parent 982dfe99d7): feat: kinda broken

Former-commit-id: 4cf02610721ba30c3dd1be6377daeeed907bc651 [formerly 52ef07ec8a123ddd362ac7c930eb6c915848e8b4]
Former-commit-id: 29fc18cae50625fd1f2868fc9696ca505f5648e2
.gitignore (vendored), 1 line changed
@@ -4,6 +4,7 @@ __pycache__/
 wandb/
 images/
+lightning_logs/
 checkpoints/
 *.pth
 
poetry.lock (generated), 39 lines changed
@@ -189,6 +189,17 @@ category = "main"
 optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
 
+[[package]]
+name = "commonmark"
+version = "0.9.1"
+description = "Python parser for the CommonMark Markdown spec"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
+
 [[package]]
 name = "cycler"
 version = "0.11.0"
@@ -881,7 +892,7 @@ python-versions = ">=3.6"
 name = "pygments"
 version = "2.12.0"
 description = "Pygments is a syntax highlighting package written in Python."
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.6"
 
@@ -1027,6 +1038,22 @@ requests = ">=2.0.0"
 [package.extras]
 rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
 
+[[package]]
+name = "rich"
+version = "12.4.4"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+category = "main"
+optional = false
+python-versions = ">=3.6.3,<4.0.0"
+
+[package.dependencies]
+commonmark = ">=0.9.0,<0.10.0"
+pygments = ">=2.6.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
+
 [[package]]
 name = "rsa"
 version = "4.8"
@@ -1446,7 +1473,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.8,<3.11"
-content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5"
+content-hash = "416650c968a0021f7d64028f272464d96319c361a72888ae4cb3e2a602873832"
 
 [metadata.files]
 absl-py = [
@@ -1652,6 +1679,10 @@ colorama = [
     {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
     {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
 ]
+commonmark = [
+    {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
+    {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
+]
 cycler = [
     {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
     {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
@@ -2419,6 +2450,10 @@ requests-oauthlib = [
     {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
     {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
 ]
+rich = [
+    {file = "rich-12.4.4-py3-none-any.whl", hash = "sha256:d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec"},
+    {file = "rich-12.4.4.tar.gz", hash = "sha256:4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2"},
+]
 rsa = [
     {file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"},
     {file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"},
pyproject.toml

@@ -15,6 +15,7 @@ torch = "^1.11.0"
 torchvision = "^0.12.0"
 tqdm = "^4.64.0"
 wandb = "^0.12.19"
+rich = "^12.4.4"
 
 [tool.poetry.dev-dependencies]
 black = "^22.3.0"
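Taken together, the lock-file and pyproject changes above are just the dependency side of this commit: declaring rich = "^12.4.4" (e.g. via poetry add rich) pulls in commonmark as a new transitive dependency, promotes pygments from the dev category to main because rich depends on it, and rewrites the lock's content-hash. Nothing here changes runtime behaviour beyond making pytorch_lightning's rich-based progress bar importable.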
src/train.py, 114 lines changed
@@ -3,8 +3,8 @@ import logging
 import albumentations as A
 import pytorch_lightning as pl
 import torch
-import yaml
 from albumentations.pytorch import ToTensorV2
+from pytorch_lightning.callbacks import RichProgressBar
 from pytorch_lightning.loggers import WandbLogger
 from torch.utils.data import DataLoader
 
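The yaml import goes away because the config-dump logging further down is deleted; RichProgressBar comes in for the new trainer callback. A minimal, self-contained sketch of that callback's use (illustrative values, not repo code):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

# Trainer accepts a single callback or a list; RichProgressBar replaces
# the default tqdm-style progress bar with a rich-rendered one (hence
# the rich dependency added to pyproject above).
trainer = pl.Trainer(max_epochs=5, callbacks=[RichProgressBar()])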
@@ -13,8 +13,27 @@ from src.utils.dataset import SphereDataset
 from unet import UNet
 from utils.paste import RandomPaste
 
-class_labels = {
-    1: "sphere",
+CONFIG = {
+    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
+    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
+    "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
+    "DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
+    "DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
+    "FEATURES": [64, 128, 256, 512],
+    "N_CHANNELS": 3,
+    "N_CLASSES": 1,
+    "AMP": True,
+    "PIN_MEMORY": True,
+    "BENCHMARK": True,
+    "DEVICE": "gpu",
+    "WORKERS": 8,
+    "EPOCHS": 5,
+    "BATCH_SIZE": 16,
+    "LEARNING_RATE": 1e-4,
+    "WEIGHT_DECAY": 1e-8,
+    "MOMENTUM": 0.9,
+    "IMG_SIZE": 512,
+    "SPHERES": 5,
 }
 
 if __name__ == "__main__":
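The class_labels dict is dropped in favour of a module-level CONFIG dict that now holds every hyper-parameter. The upside is that the values exist before any wandb run is initialized; the cost is that string keys get no help from linters. A typed variant, purely as an illustrative sketch (names here are hypothetical, not from the repo):

from dataclasses import dataclass

# hypothetical typed equivalent of a few CONFIG entries; attribute
# access gives IDEs and type checkers something to verify
@dataclass(frozen=True)
class TrainConfig:
    batch_size: int = 16
    epochs: int = 5
    img_size: int = 512

cfg = TrainConfig()
assert cfg.batch_size == 16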
@@ -24,28 +43,7 @@ if __name__ == "__main__":
     # setup wandb
     logger = WandbLogger(
         project="U-Net",
-        config=dict(
-            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
-            DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
-            DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
-            DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
-            DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-            FEATURES=[64, 128, 256, 512],
-            N_CHANNELS=3,
-            N_CLASSES=1,
-            AMP=True,
-            PIN_MEMORY=True,
-            BENCHMARK=True,
-            DEVICE="gpu",
-            WORKERS=8,
-            EPOCHS=5,
-            BATCH_SIZE=16,
-            LEARNING_RATE=1e-4,
-            WEIGHT_DECAY=1e-8,
-            MOMENTUM=0.9,
-            IMG_SIZE=512,
-            SPHERES=5,
-        ),
+        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),
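Passing config=CONFIG still records every value in the run's wandb config, but because the code below now reads from CONFIG rather than wandb.config, values overridden at launch time (for instance by a wandb sweep) would silently be ignored. If that matters, one way to sync overrides back (a hypothetical addition, not in this commit):

import wandb

run = wandb.init(project="U-Net", config=CONFIG)
# wandb applies sweep overrides to run.config; copy them back so the
# rest of the script, which indexes CONFIG, sees the effective values
CONFIG.update(run.config.as_dict())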
@@ -55,10 +53,7 @@ if __name__ == "__main__":
     pl.seed_everything(69420, workers=True)
 
     # 0. Create network
-    net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
-
-    # log the number of parameters of the model
-    wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
+    net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"])
 
     # log gradients and weights regularly
     logger.watch(net, log="all")
@@ -66,88 +61,59 @@ if __name__ == "__main__":
     # 1. Create transforms
     tf_train = A.Compose(
         [
-            A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+            A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
             A.Flip(),
             A.ColorJitter(),
-            RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
+            RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
             A.GaussianBlur(),
             A.ISONoise(),
             A.ToFloat(max_value=255),
             ToTensorV2(),
         ],
     )
-    tf_valid = A.Compose(
-        [
-            A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-            RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
 
     # 2. Create datasets
-    ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
-    ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
-    ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+    ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
+    ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
 
     # 2.5. Create subset, if uncommented
     ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
-    ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
-    ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
+    # ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
+    # ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
 
     # 3. Create data loaders
     train_loader = DataLoader(
         ds_train,
         shuffle=True,
-        batch_size=wandb.config.BATCH_SIZE,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
+        batch_size=CONFIG["BATCH_SIZE"],
+        num_workers=CONFIG["WORKERS"],
+        pin_memory=CONFIG["PIN_MEMORY"],
     )
     val_loader = DataLoader(
         ds_valid,
         shuffle=False,
         drop_last=True,
-        batch_size=wandb.config.BATCH_SIZE,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
-    )
-    test_loader = DataLoader(
-        ds_test,
-        shuffle=False,
-        drop_last=False,
         batch_size=1,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
+        num_workers=CONFIG["WORKERS"],
+        pin_memory=CONFIG["PIN_MEMORY"],
     )
 
     # 4. Create the trainer
     trainer = pl.Trainer(
-        max_epochs=wandb.config.EPOCHS,
-        accelerator="gpu",
-        precision=16,
+        max_epochs=CONFIG["EPOCHS"],
+        accelerator=CONFIG["DEVICE"],
+        # precision=16,
         auto_scale_batch_size="binsearch",
-        benchmark=wandb.config.BENCHMARK,
+        benchmark=CONFIG["BENCHMARK"],
         val_check_interval=100,
+        callbacks=RichProgressBar(),
     )
 
-    # print the config
-    logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
-
-    # # wandb init log
-    # wandb.log(
-    #     {
-    #         "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-    #     },
-    #     commit=False,
-    # )
-
     try:
         trainer.fit(
             model=net,
             train_dataloaders=train_loader,
             val_dataloaders=val_loader,
-            test_dataloaders=test_loader,
-            accelerator=wandb.config.DEVICE,
         )
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
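Dropping test_dataloaders and accelerator from trainer.fit is a correctness fix as much as a simplification: in the Lightning 1.6-era API this code targets, fit() accepts neither, accelerator belongs to the Trainer constructor, and testing is a separate call. A sketch of how the removed test loop could be run instead (assuming a test_loader built the same way as val_loader, which this commit deletes):

# after fitting, evaluate on the held-out split in its own call
trainer.test(model=net, dataloaders=test_loader)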
src/unet, UNet LightningModule (exact file path not preserved in the mirror)
@@ -1,7 +1,5 @@
 """ Full assembly of the parts to form the complete network """
 
-from xmlrpc.server import list_public_methods
-
 import numpy as np
 import pytorch_lightning as pl
 
@@ -40,6 +38,7 @@ class UNet(pl.LightningModule):
     def forward(self, x):
         skips = []
 
+        x = x.to(self.device)
         x = self.inc(x)
 
         for down in self.downs:
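The new x.to(self.device) guard makes forward usable with tensors that live on a different device than the module, e.g. when it is called directly rather than through the trainer (the commented-out ONNX export below feeds it a CPU tensor). A small runnable illustration of the pattern, with a toy module standing in for UNet:

import torch
import torch.nn as nn
import pytorch_lightning as pl

class Toy(pl.LightningModule):
    # stand-in for UNet, just to show the device-move guard
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        x = x.to(self.device)  # same guard the commit adds to UNet.forward
        return self.conv(x)

net = Toy()
print(net(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 1, 8, 8])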
@@ -53,8 +52,7 @@ class UNet(pl.LightningModule):
 
         return x
 
-    @staticmethod
-    def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
+    def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key):
         table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
 
         for i, (img, mask, pred, pred_bin) in enumerate(
@@ -99,16 +97,17 @@ class UNet(pl.LightningModule):
         accuracy = (masks_true == masks_pred_bin).float().mean()
         dice = dice_coeff(masks_pred_bin, masks_true)
 
-        wandb.log(
+        self.log(
+            "train",
             {
-                "train/accuracy": accuracy,
-                "train/bce": loss,
-                "train/dice": dice,
-                "train/mae": mae,
-            }
+                "accuracy": accuracy,
+                "bce": loss,
+                "dice": dice,
+                "mae": mae,
+            },
         )
 
-        return loss, dice, accuracy, mae
+        return loss  # , dice, accuracy, mae
 
     def validation_step(self, batch, batch_idx):
         # unpacking
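A caveat that may be part of the commit's "kinda broken" label: LightningModule.log expects a scalar value for a named metric, while logging a dictionary of metrics is what log_dict is for, so self.log("train", {...}) is unlikely to work as written. The equivalent call would look roughly like this (same metrics, naming assumed):

# inside training_step: one scalar per key; slash-prefixed keys keep
# the wandb grouping that the old wandb.log call produced
self.log_dict(
    {
        "train/accuracy": accuracy,
        "train/bce": loss,
        "train/dice": dice,
        "train/mae": mae,
    }
)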
@@ -119,79 +118,79 @@ class UNet(pl.LightningModule):
 
         # compute metrics
         loss = F.cross_entropy(masks_pred, masks_true)
-        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-        accuracy = (masks_true == masks_pred_bin).float().mean()
-        dice = dice_coeff(masks_pred_bin, masks_true)
+        # mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        # accuracy = (masks_true == masks_pred_bin).float().mean()
+        # dice = dice_coeff(masks_pred_bin, masks_true)
 
         if batch_idx == 0:
             self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
 
-        return loss, dice, accuracy, mae
+        return loss  # , dice, accuracy, mae
 
-    def validation_step_end(self, validation_outputs):
-        # unpacking
-        loss, dice, accuracy, mae = validation_outputs
-        optimizer = self.optimizers[0]
-        learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
+    # def validation_step_end(self, validation_outputs):
+    #     # unpacking
+    #     loss, dice, accuracy, mae = validation_outputs
+    #     # optimizer = self.optimizers[0]
+    #     # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
 
-        wandb.log(
-            {
-                "train/learning_rate": learning_rate,
-                "val/accuracy": accuracy,
-                "val/bce": loss,
-                "val/dice": dice,
-                "val/mae": mae,
-            }
-        )
+    #     wandb.log(
+    #         {
+    #             # "train/learning_rate": learning_rate,
+    #             "val/accuracy": accuracy,
+    #             "val/bce": loss,
+    #             "val/dice": dice,
+    #             "val/mae": mae,
+    #         }
+    #     )
 
-        # export model to onnx
-        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
-        torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
-        artifact = wandb.Artifact("onnx", type="model")
-        artifact.add_file(f"checkpoints/model.onnx")
-        wandb.run.log_artifact(artifact)
+    #     # export model to onnx
+    #     dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
+    #     torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
+    #     artifact = wandb.Artifact("onnx", type="model")
+    #     artifact.add_file(f"checkpoints/model.onnx")
+    #     wandb.run.log_artifact(artifact)
 
-    def test_step(self, batch, batch_idx):
-        # unpacking
-        images, masks_true = batch
-        masks_true = masks_true.unsqueeze(1)
-        masks_pred = self(images)
-        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+    # def test_step(self, batch, batch_idx):
+    #     # unpacking
+    #     images, masks_true = batch
+    #     masks_true = masks_true.unsqueeze(1)
+    #     masks_pred = self(images)
+    #     masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
 
-        # compute metrics
-        loss = F.cross_entropy(masks_pred, masks_true)
-        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-        accuracy = (masks_true == masks_pred_bin).float().mean()
-        dice = dice_coeff(masks_pred_bin, masks_true)
+    #     # compute metrics
+    #     loss = F.cross_entropy(masks_pred, masks_true)
+    #     mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+    #     accuracy = (masks_true == masks_pred_bin).float().mean()
+    #     dice = dice_coeff(masks_pred_bin, masks_true)
 
-        if batch_idx == 0:
-            self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
+    #     if batch_idx == 0:
+    #         self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
 
-        return loss, dice, accuracy, mae
+    #     return loss, dice, accuracy, mae
 
-    def test_step_end(self, test_outputs):
-        # unpacking
-        list_loss, list_dice, list_accuracy, list_mae = test_outputs
+    # def test_step_end(self, test_outputs):
+    #     # unpacking
+    #     list_loss, list_dice, list_accuracy, list_mae = test_outputs
 
-        # averaging
-        loss = np.mean(list_loss)
-        dice = np.mean(list_dice)
-        accuracy = np.mean(list_accuracy)
-        mae = np.mean(list_mae)
+    #     # averaging
+    #     loss = np.mean(list_loss)
+    #     dice = np.mean(list_dice)
+    #     accuracy = np.mean(list_accuracy)
+    #     mae = np.mean(list_mae)
 
-        # get learning rate
-        optimizer = self.optimizers[0]
-        learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
+    #     # # get learning rate
+    #     # optimizer = self.optimizers[0]
+    #     # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
 
-        wandb.log(
-            {
-                "train/learning_rate": learning_rate,
-                "val/accuracy": accuracy,
-                "val/bce": loss,
-                "val/dice": dice,
-                "val/mae": mae,
-            }
-        )
+    #     wandb.log(
+    #         {
+    #             # "train/learning_rate": learning_rate,
+    #             "test/accuracy": accuracy,
+    #             "test/bce": loss,
+    #             "test/dice": dice,
+    #             "test/mae": mae,
+    #         }
+    #     )
 
     def configure_optimizers(self):
         optimizer = torch.optim.RMSprop(
@@ -200,10 +199,10 @@ class UNet(pl.LightningModule):
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer,
-            "max",
-            patience=2,
-        )
+        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        #     optimizer,
+        #     "max",
+        #     patience=2,
+        # )
 
-        return optimizer, scheduler
+        return optimizer  # , scheduler
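Commenting the scheduler out sidesteps another issue: a bare (optimizer, scheduler) tuple is not one of the return shapes Lightning's configure_optimizers documents. If the ReduceLROnPlateau comes back, a Lightning-accepted shape would look roughly like this sketch (lr wiring assumed from the surrounding config; the monitor key is an assumption, since a plateau scheduler needs a monitored metric):

optimizer = torch.optim.RMSprop(
    self.parameters(),
    lr=wandb.config.LEARNING_RATE,
    weight_decay=wandb.config.WEIGHT_DECAY,
    momentum=wandb.config.MOMENTUM,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
# a dict return lets Lightning step the scheduler against a logged metric
return {
    "optimizer": optimizer,
    "lr_scheduler": {"scheduler": scheduler, "monitor": "val/dice"},
}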