mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 23:12:05 +00:00
feat: display binary mask on top of prediction
Former-commit-id: a81b30919c6ef0998617822fd57e54209dcb2cfa [formerly 9f9cf75289c5f731dc6c93f86dff237aed88722c] Former-commit-id: d7272aefcf16e2e22081811c0452d7cb7e0980f3
This commit is contained in:
parent
31ceb97996
commit
beafa768f7
22
src/train.py
22
src/train.py
|
@ -13,6 +13,10 @@ from unet import UNet
|
||||||
from utils.dice import dice_coeff
|
from utils.dice import dice_coeff
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
|
class_labels = {
|
||||||
|
1: "sphere",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# setup logging
|
# setup logging
|
||||||
|
@ -215,14 +219,28 @@ def main():
|
||||||
|
|
||||||
# save the last validation batch to table
|
# save the last validation batch to table
|
||||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
for i, (img, mask, pred) in enumerate(
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
zip(
|
zip(
|
||||||
images.to("cpu"),
|
images.to("cpu"),
|
||||||
masks_true.to("cpu"),
|
masks_true.to("cpu"),
|
||||||
masks_pred.to("cpu"),
|
masks_pred.to("cpu"),
|
||||||
|
masks_pred_bin.to("cpu").squeeze().int().numpy(),
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
table.add_data(
|
||||||
|
i,
|
||||||
|
wandb.Image(img),
|
||||||
|
wandb.Image(mask),
|
||||||
|
wandb.Image(
|
||||||
|
pred,
|
||||||
|
masks={
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_bin,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# log validation metrics
|
# log validation metrics
|
||||||
wandb.log(
|
wandb.log(
|
||||||
|
|
Loading…
Reference in a new issue