mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 23:12:05 +00:00
revert: want back to table display for predictions
Former-commit-id: 2af16f486ffeaca232aee35d360fe75e78adfc9e [formerly e072ee4b33587b45f6c6c34f809b9638a9e0b569] Former-commit-id: 69c874da0fbe6e2a5cb2421fa20ffb9c34b36fc1
This commit is contained in:
parent
dc4a399c0f
commit
f4ed2f799e
|
@ -34,28 +34,10 @@ def evaluate(net, dataloader, device):
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
|
|
||||||
# save some images to wandb
|
# save some images to wandb
|
||||||
overlays = []
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
|
for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
|
||||||
mask_img = np.asarray(mask > 0.5, np.uint8).squeeze(0) # tester des trucs sans le threshold
|
table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
|
||||||
pred_img = np.asarray(pred > 0.5, np.uint8).squeeze(0)
|
wandb.log({"predictions_table": table})
|
||||||
|
|
||||||
overlays.append(
|
|
||||||
wandb.Image(
|
|
||||||
img,
|
|
||||||
masks={
|
|
||||||
"ground_truth": {
|
|
||||||
"mask_data": mask_img,
|
|
||||||
"class_labels": class_labels,
|
|
||||||
},
|
|
||||||
"predictions": {
|
|
||||||
"mask_data": pred_img,
|
|
||||||
"class_labels": class_labels,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
wandb.log({"val/images": overlays})
|
|
||||||
|
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue