mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: kinda broken
Former-commit-id: 4cf02610721ba30c3dd1be6377daeeed907bc651 [formerly 52ef07ec8a123ddd362ac7c930eb6c915848e8b4] Former-commit-id: 29fc18cae50625fd1f2868fc9696ca505f5648e2
This commit is contained in:
parent
982dfe99d7
commit
e4562e2481
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,6 +4,7 @@ __pycache__/
|
|||
|
||||
wandb/
|
||||
images/
|
||||
lightning_logs/
|
||||
|
||||
checkpoints/
|
||||
*.pth
|
||||
|
|
39
poetry.lock
generated
39
poetry.lock
generated
|
@ -189,6 +189,17 @@ category = "main"
|
|||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
|
||||
[[package]]
|
||||
name = "commonmark"
|
||||
version = "0.9.1"
|
||||
description = "Python parser for the CommonMark Markdown spec"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.extras]
|
||||
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "cycler"
|
||||
version = "0.11.0"
|
||||
|
@ -881,7 +892,7 @@ python-versions = ">=3.6"
|
|||
name = "pygments"
|
||||
version = "2.12.0"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
|
||||
|
@ -1027,6 +1038,22 @@ requests = ">=2.0.0"
|
|||
[package.extras]
|
||||
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "12.4.4"
|
||||
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6.3,<4.0.0"
|
||||
|
||||
[package.dependencies]
|
||||
commonmark = ">=0.9.0,<0.10.0"
|
||||
pygments = ">=2.6.0,<3.0.0"
|
||||
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
|
||||
|
||||
[package.extras]
|
||||
jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "4.8"
|
||||
|
@ -1446,7 +1473,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8,<3.11"
|
||||
content-hash = "b192d0e5f593e99630bb92cd31c510dcdea67b0b54861176f92f50505724e7d5"
|
||||
content-hash = "416650c968a0021f7d64028f272464d96319c361a72888ae4cb3e2a602873832"
|
||||
|
||||
[metadata.files]
|
||||
absl-py = [
|
||||
|
@ -1652,6 +1679,10 @@ colorama = [
|
|||
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
|
||||
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
|
||||
]
|
||||
commonmark = [
|
||||
{file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
|
||||
{file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
|
||||
]
|
||||
cycler = [
|
||||
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
|
||||
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
|
||||
|
@ -2419,6 +2450,10 @@ requests-oauthlib = [
|
|||
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
|
||||
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
|
||||
]
|
||||
rich = [
|
||||
{file = "rich-12.4.4-py3-none-any.whl", hash = "sha256:d2bbd99c320a2532ac71ff6a3164867884357da3e3301f0240090c5d2fdac7ec"},
|
||||
{file = "rich-12.4.4.tar.gz", hash = "sha256:4c586de507202505346f3e32d1363eb9ed6932f0c2f63184dea88983ff4971e2"},
|
||||
]
|
||||
rsa = [
|
||||
{file = "rsa-4.8-py3-none-any.whl", hash = "sha256:95c5d300c4e879ee69708c428ba566c59478fd653cc3a22243eeb8ed846950bb"},
|
||||
{file = "rsa-4.8.tar.gz", hash = "sha256:5c6bd9dc7a543b7fe4304a631f8a8a3b674e2bbfc49c2ae96200cdbe55df6b17"},
|
||||
|
|
|
@ -15,6 +15,7 @@ torch = "^1.11.0"
|
|||
torchvision = "^0.12.0"
|
||||
tqdm = "^4.64.0"
|
||||
wandb = "^0.12.19"
|
||||
rich = "^12.4.4"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^22.3.0"
|
||||
|
|
114
src/train.py
114
src/train.py
|
@ -3,8 +3,8 @@ import logging
|
|||
import albumentations as A
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import yaml
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from pytorch_lightning.callbacks import RichProgressBar
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
@ -13,8 +13,27 @@ from src.utils.dataset import SphereDataset
|
|||
from unet import UNet
|
||||
from utils.paste import RandomPaste
|
||||
|
||||
class_labels = {
|
||||
1: "sphere",
|
||||
CONFIG = {
|
||||
"DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
|
||||
"DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
|
||||
"DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
|
||||
"DIR_SPHERE_IMG": "/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||
"DIR_SPHERE_MASK": "/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||
"FEATURES": [64, 128, 256, 512],
|
||||
"N_CHANNELS": 3,
|
||||
"N_CLASSES": 1,
|
||||
"AMP": True,
|
||||
"PIN_MEMORY": True,
|
||||
"BENCHMARK": True,
|
||||
"DEVICE": "gpu",
|
||||
"WORKERS": 8,
|
||||
"EPOCHS": 5,
|
||||
"BATCH_SIZE": 16,
|
||||
"LEARNING_RATE": 1e-4,
|
||||
"WEIGHT_DECAY": 1e-8,
|
||||
"MOMENTUM": 0.9,
|
||||
"IMG_SIZE": 512,
|
||||
"SPHERES": 5,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -24,28 +43,7 @@ if __name__ == "__main__":
|
|||
# setup wandb
|
||||
logger = WandbLogger(
|
||||
project="U-Net",
|
||||
config=dict(
|
||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
|
||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
|
||||
DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
|
||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||
FEATURES=[64, 128, 256, 512],
|
||||
N_CHANNELS=3,
|
||||
N_CLASSES=1,
|
||||
AMP=True,
|
||||
PIN_MEMORY=True,
|
||||
BENCHMARK=True,
|
||||
DEVICE="gpu",
|
||||
WORKERS=8,
|
||||
EPOCHS=5,
|
||||
BATCH_SIZE=16,
|
||||
LEARNING_RATE=1e-4,
|
||||
WEIGHT_DECAY=1e-8,
|
||||
MOMENTUM=0.9,
|
||||
IMG_SIZE=512,
|
||||
SPHERES=5,
|
||||
),
|
||||
config=CONFIG,
|
||||
settings=wandb.Settings(
|
||||
code_dir="./src/",
|
||||
),
|
||||
|
@ -55,10 +53,7 @@ if __name__ == "__main__":
|
|||
pl.seed_everything(69420, workers=True)
|
||||
|
||||
# 0. Create network
|
||||
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||
|
||||
# log the number of parameters of the model
|
||||
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||
net = UNet(n_channels=CONFIG["N_CHANNELS"], n_classes=CONFIG["N_CLASSES"], features=CONFIG["FEATURES"])
|
||||
|
||||
# log gradients and weights regularly
|
||||
logger.watch(net, log="all")
|
||||
|
@ -66,88 +61,59 @@ if __name__ == "__main__":
|
|||
# 1. Create transforms
|
||||
tf_train = A.Compose(
|
||||
[
|
||||
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||
A.Resize(CONFIG["IMG_SIZE"], CONFIG["IMG_SIZE"]),
|
||||
A.Flip(),
|
||||
A.ColorJitter(),
|
||||
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
||||
RandomPaste(CONFIG["SPHERES"], CONFIG["DIR_SPHERE_IMG"], CONFIG["DIR_SPHERE_MASK"]),
|
||||
A.GaussianBlur(),
|
||||
A.ISONoise(),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
)
|
||||
tf_valid = A.Compose(
|
||||
[
|
||||
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
)
|
||||
|
||||
# 2. Create datasets
|
||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
|
||||
ds_train = SphereDataset(image_dir=CONFIG["DIR_TRAIN_IMG"], transform=tf_train)
|
||||
ds_valid = SphereDataset(image_dir=CONFIG["DIR_TEST_IMG"])
|
||||
|
||||
# 2.5. Create subset, if uncommented
|
||||
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
|
||||
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
|
||||
# ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
|
||||
# ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
|
||||
|
||||
# 3. Create data loaders
|
||||
train_loader = DataLoader(
|
||||
ds_train,
|
||||
shuffle=True,
|
||||
batch_size=wandb.config.BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
batch_size=CONFIG["BATCH_SIZE"],
|
||||
num_workers=CONFIG["WORKERS"],
|
||||
pin_memory=CONFIG["PIN_MEMORY"],
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
ds_valid,
|
||||
shuffle=False,
|
||||
drop_last=True,
|
||||
batch_size=wandb.config.BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
ds_test,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
batch_size=1,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
num_workers=CONFIG["WORKERS"],
|
||||
pin_memory=CONFIG["PIN_MEMORY"],
|
||||
)
|
||||
|
||||
# 4. Create the trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=wandb.config.EPOCHS,
|
||||
accelerator="gpu",
|
||||
precision=16,
|
||||
max_epochs=CONFIG["EPOCHS"],
|
||||
accelerator=CONFIG["DEVICE"],
|
||||
# precision=16,
|
||||
auto_scale_batch_size="binsearch",
|
||||
benchmark=wandb.config.BENCHMARK,
|
||||
benchmark=CONFIG["BENCHMARK"],
|
||||
val_check_interval=100,
|
||||
callbacks=RichProgressBar(),
|
||||
)
|
||||
|
||||
# print the config
|
||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||
|
||||
# # wandb init log
|
||||
# wandb.log(
|
||||
# {
|
||||
# "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
||||
# },
|
||||
# commit=False,
|
||||
# )
|
||||
|
||||
try:
|
||||
trainer.fit(
|
||||
model=net,
|
||||
train_dataloaders=train_loader,
|
||||
val_dataloaders=val_loader,
|
||||
test_dataloaders=test_loader,
|
||||
accelerator=wandb.config.DEVICE,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
""" Full assembly of the parts to form the complete network """
|
||||
|
||||
from xmlrpc.server import list_public_methods
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
@ -40,6 +38,7 @@ class UNet(pl.LightningModule):
|
|||
def forward(self, x):
|
||||
skips = []
|
||||
|
||||
x = x.to(self.device)
|
||||
x = self.inc(x)
|
||||
|
||||
for down in self.downs:
|
||||
|
@ -53,8 +52,7 @@ class UNet(pl.LightningModule):
|
|||
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def save_to_table(images, masks_true, masks_pred, masks_pred_bin, log_key):
|
||||
def save_to_table(self, images, masks_true, masks_pred, masks_pred_bin, log_key):
|
||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||
|
||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||
|
@ -99,16 +97,17 @@ class UNet(pl.LightningModule):
|
|||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
wandb.log(
|
||||
self.log(
|
||||
"train",
|
||||
{
|
||||
"train/accuracy": accuracy,
|
||||
"train/bce": loss,
|
||||
"train/dice": dice,
|
||||
"train/mae": mae,
|
||||
}
|
||||
"accuracy": accuracy,
|
||||
"bce": loss,
|
||||
"dice": dice,
|
||||
"mae": mae,
|
||||
},
|
||||
)
|
||||
|
||||
return loss, dice, accuracy, mae
|
||||
return loss # , dice, accuracy, mae
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
# unpacking
|
||||
|
@ -119,79 +118,79 @@ class UNet(pl.LightningModule):
|
|||
|
||||
# compute metrics
|
||||
loss = F.cross_entropy(masks_pred, masks_true)
|
||||
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
# mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
# accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
# dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
if batch_idx == 0:
|
||||
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "val/predictions")
|
||||
|
||||
return loss, dice, accuracy, mae
|
||||
return loss # , dice, accuracy, mae
|
||||
|
||||
def validation_step_end(self, validation_outputs):
|
||||
# unpacking
|
||||
loss, dice, accuracy, mae = validation_outputs
|
||||
optimizer = self.optimizers[0]
|
||||
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||
# def validation_step_end(self, validation_outputs):
|
||||
# # unpacking
|
||||
# loss, dice, accuracy, mae = validation_outputs
|
||||
# # optimizer = self.optimizers[0]
|
||||
# # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
"train/learning_rate": learning_rate,
|
||||
"val/accuracy": accuracy,
|
||||
"val/bce": loss,
|
||||
"val/dice": dice,
|
||||
"val/mae": mae,
|
||||
}
|
||||
)
|
||||
# wandb.log(
|
||||
# {
|
||||
# # "train/learning_rate": learning_rate,
|
||||
# "val/accuracy": accuracy,
|
||||
# "val/bce": loss,
|
||||
# "val/dice": dice,
|
||||
# "val/mae": mae,
|
||||
# }
|
||||
# )
|
||||
|
||||
# export model to onnx
|
||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
||||
torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
||||
artifact = wandb.Artifact("onnx", type="model")
|
||||
artifact.add_file(f"checkpoints/model.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
# # export model to onnx
|
||||
# dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
|
||||
# torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
|
||||
# artifact = wandb.Artifact("onnx", type="model")
|
||||
# artifact.add_file(f"checkpoints/model.onnx")
|
||||
# wandb.run.log_artifact(artifact)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
# unpacking
|
||||
images, masks_true = batch
|
||||
masks_true = masks_true.unsqueeze(1)
|
||||
masks_pred = self(images)
|
||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
# def test_step(self, batch, batch_idx):
|
||||
# # unpacking
|
||||
# images, masks_true = batch
|
||||
# masks_true = masks_true.unsqueeze(1)
|
||||
# masks_pred = self(images)
|
||||
# masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
|
||||
# compute metrics
|
||||
loss = F.cross_entropy(masks_pred, masks_true)
|
||||
mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
# # compute metrics
|
||||
# loss = F.cross_entropy(masks_pred, masks_true)
|
||||
# mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||
# accuracy = (masks_true == masks_pred_bin).float().mean()
|
||||
# dice = dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
if batch_idx == 0:
|
||||
self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
|
||||
# if batch_idx == 0:
|
||||
# self.save_to_table(images, masks_true, masks_pred, masks_pred_bin, "test/predictions")
|
||||
|
||||
return loss, dice, accuracy, mae
|
||||
# return loss, dice, accuracy, mae
|
||||
|
||||
def test_step_end(self, test_outputs):
|
||||
# unpacking
|
||||
list_loss, list_dice, list_accuracy, list_mae = test_outputs
|
||||
# def test_step_end(self, test_outputs):
|
||||
# # unpacking
|
||||
# list_loss, list_dice, list_accuracy, list_mae = test_outputs
|
||||
|
||||
# averaging
|
||||
loss = np.mean(list_loss)
|
||||
dice = np.mean(list_dice)
|
||||
accuracy = np.mean(list_accuracy)
|
||||
mae = np.mean(list_mae)
|
||||
# # averaging
|
||||
# loss = np.mean(list_loss)
|
||||
# dice = np.mean(list_dice)
|
||||
# accuracy = np.mean(list_accuracy)
|
||||
# mae = np.mean(list_mae)
|
||||
|
||||
# get learning rate
|
||||
optimizer = self.optimizers[0]
|
||||
learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||
# # # get learning rate
|
||||
# # optimizer = self.optimizers[0]
|
||||
# # learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
"train/learning_rate": learning_rate,
|
||||
"val/accuracy": accuracy,
|
||||
"val/bce": loss,
|
||||
"val/dice": dice,
|
||||
"val/mae": mae,
|
||||
}
|
||||
)
|
||||
# wandb.log(
|
||||
# {
|
||||
# # "train/learning_rate": learning_rate,
|
||||
# "test/accuracy": accuracy,
|
||||
# "test/bce": loss,
|
||||
# "test/dice": dice,
|
||||
# "test/mae": mae,
|
||||
# }
|
||||
# )
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.RMSprop(
|
||||
|
@ -200,10 +199,10 @@ class UNet(pl.LightningModule):
|
|||
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||
momentum=wandb.config.MOMENTUM,
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer,
|
||||
"max",
|
||||
patience=2,
|
||||
)
|
||||
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
# optimizer,
|
||||
# "max",
|
||||
# patience=2,
|
||||
# )
|
||||
|
||||
return optimizer, scheduler
|
||||
return optimizer # , scheduler
|
||||
|
|
Loading…
Reference in a new issue