feat: validate every 100 steps

Former-commit-id: a3367de4ed56c5a708d66e7cd6be27d52bb92ccc [formerly 4606a91526eae57d56fc93df7ed34b867495e1c5]
Former-commit-id: 6584449cd25b18ddd46f6804c8b1653e1c72dda0
This commit is contained in:
Laurent Fainsin 2022-07-01 15:31:53 +02:00
parent cf8f52735a
commit 2571e5c6d3

View file

@@ -137,7 +137,7 @@ if __name__ == "__main__":
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# log gradients and weights four times per epoch # log gradients and weights four times per epoch
wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4) wandb.watch(net, criterion, log_freq=100)
# print the config # print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
@@ -198,92 +198,87 @@ if __name__ == "__main__":
} }
) )
# Evaluation round if step and (step % 100 == 0 or step == len(train_loader)):
net.eval() # Evaluation round
accuracy = 0 net.eval()
val_loss = 0 accuracy = 0
dice = 0 val_loss = 0
mae = 0 dice = 0
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar: mae = 0
for images, masks_true in val_loader: with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
for images, masks_true in val_loader:
# transfer images to device # transfer images to device
images = images.to(device=device) images = images.to(device=device)
masks_true = masks_true.unsqueeze(1).to(device=device) masks_true = masks_true.unsqueeze(1).to(device=device)
# forward # forward
with torch.inference_mode(): with torch.inference_mode():
masks_pred = net(images) masks_pred = net(images)
# compute metrics # compute metrics
val_loss += criterion(pred_masks, true_masks) val_loss += criterion(pred_masks, true_masks)
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks) mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
accuracy += (true_masks == pred_masks_bin).float().mean() accuracy += (true_masks == pred_masks_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true) dice += dice_coeff(masks_pred_bin, masks_true)
# update progress bar # update progress bar
pbar.update(images.shape[0]) pbar2.update(images.shape[0])
accuracy /= len(val_loader) accuracy /= len(val_loader)
val_loss /= len(val_loader) val_loss /= len(val_loader)
dice /= len(val_loader) dice /= len(val_loader)
mae /= len(val_loader) mae /= len(val_loader)
# save the last validation batch to table # save the last validation batch to table
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred, pred_bin) in enumerate( for i, (img, mask, pred, pred_bin) in enumerate(
zip( zip(
images.to("cpu"), images.to("cpu"),
masks_true.to("cpu"), masks_true.to("cpu"),
masks_pred.to("cpu"), masks_pred.to("cpu"),
masks_pred_bin.to("cpu").squeeze().int().numpy(), masks_pred_bin.to("cpu").squeeze().int().numpy(),
) )
): ):
table.add_data( table.add_data(
i, i,
wandb.Image(img), wandb.Image(img),
wandb.Image(mask), wandb.Image(mask),
wandb.Image( wandb.Image(
pred, pred,
masks={ masks={
"predictions": { "predictions": {
"mask_data": pred_bin, "mask_data": pred_bin,
"class_labels": class_labels, "class_labels": class_labels,
}, },
},
),
)
# log validation metrics
wandb.log(
{
"predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy,
"val/bce": val_loss,
"val/dice": dice,
"val/mae": mae,
}, },
), commit=False,
) )
# log validation metrics # update hyperparameters
wandb.log( net.train()
{ scheduler.step(dice)
"predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy,
"val/bce": val_loss,
"val/dice": dice,
"val/mae": mae,
},
commit=False,
)
# update hyperparameters # export model to onnx format when validation ends
net.train() dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
scheduler.step(dice) torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
artifact = wandb.Artifact("onnx", type="model")
# save weights when epoch ends artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth") wandb.run.log_artifact(artifact)
artifact = wandb.Artifact("pth", type="model")
artifact.add_file(f"checkpoints/model-{epoch}.pth")
wandb.run.log_artifact(artifact)
# export model to onnx format
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file(f"checkpoints/model-{epoch}.onnx")
wandb.run.log_artifact(artifact)
# stop wandb # stop wandb
wandb.run.finish() wandb.run.finish()