Mirror of https://github.com/Laurent2916/REVA-QCAV.git, synced 2024-11-09 23:12:05 +00:00
feat: validate every 100 steps
Former-commit-id: a3367de4ed56c5a708d66e7cd6be27d52bb92ccc [formerly 4606a91526eae57d56fc93df7ed34b867495e1c5]
Former-commit-id: 6584449cd25b18ddd46f6804c8b1653e1c72dda0
This commit is contained in:
  parent cf8f52735a
  commit 2571e5c6d3
src/train.py | 151 lines changed
@@ -137,7 +137,7 @@ if __name__ == "__main__":
     wandb.run.log_artifact(artifact)
 
     # log gradients and weights four time per epoch
-    wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4)
+    wandb.watch(net, criterion, log_freq=100)
 
     # print the config
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
@@ -198,92 +198,87 @@ if __name__ == "__main__":
                     }
                 )
 
-        # Evaluation round
-        net.eval()
-        accuracy = 0
-        val_loss = 0
-        dice = 0
-        mae = 0
-        with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
-            for images, masks_true in val_loader:
+                if step and (step % 100 == 0 or step == len(train_loader)):
+                    # Evaluation round
+                    net.eval()
+                    accuracy = 0
+                    val_loss = 0
+                    dice = 0
+                    mae = 0
+                    with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
+                        for images, masks_true in val_loader:
 
-                # transfer images to device
-                images = images.to(device=device)
-                masks_true = masks_true.unsqueeze(1).to(device=device)
+                            # transfer images to device
+                            images = images.to(device=device)
+                            masks_true = masks_true.unsqueeze(1).to(device=device)
 
-                # forward
-                with torch.inference_mode():
-                    masks_pred = net(images)
+                            # forward
+                            with torch.inference_mode():
+                                masks_pred = net(images)
 
-                # compute metrics
-                val_loss += criterion(pred_masks, true_masks)
-                mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
-                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                accuracy += (true_masks == pred_masks_bin).float().mean()
-                dice += dice_coeff(masks_pred_bin, masks_true)
+                            # compute metrics
+                            val_loss += criterion(pred_masks, true_masks)
+                            mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
+                            masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+                            accuracy += (true_masks == pred_masks_bin).float().mean()
+                            dice += dice_coeff(masks_pred_bin, masks_true)
 
-                # update progress bar
-                pbar.update(images.shape[0])
+                            # update progress bar
+                            pbar2.update(images.shape[0])
 
-        accuracy /= len(val_loader)
-        val_loss /= len(val_loader)
-        dice /= len(val_loader)
-        mae /= len(val_loader)
+                    accuracy /= len(val_loader)
+                    val_loss /= len(val_loader)
+                    dice /= len(val_loader)
+                    mae /= len(val_loader)
 
-        # save the last validation batch to table
-        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-        for i, (img, mask, pred, pred_bin) in enumerate(
-            zip(
-                images.to("cpu"),
-                masks_true.to("cpu"),
-                masks_pred.to("cpu"),
-                masks_pred_bin.to("cpu").squeeze().int().numpy(),
-            )
-        ):
-            table.add_data(
-                i,
-                wandb.Image(img),
-                wandb.Image(mask),
-                wandb.Image(
-                    pred,
-                    masks={
-                        "predictions": {
-                            "mask_data": pred_bin,
-                            "class_labels": class_labels,
-                        },
-                    },
-                ),
-            )
+                    # save the last validation batch to table
+                    table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
+                    for i, (img, mask, pred, pred_bin) in enumerate(
+                        zip(
+                            images.to("cpu"),
+                            masks_true.to("cpu"),
+                            masks_pred.to("cpu"),
+                            masks_pred_bin.to("cpu").squeeze().int().numpy(),
+                        )
+                    ):
+                        table.add_data(
+                            i,
+                            wandb.Image(img),
+                            wandb.Image(mask),
+                            wandb.Image(
+                                pred,
+                                masks={
+                                    "predictions": {
+                                        "mask_data": pred_bin,
+                                        "class_labels": class_labels,
+                                    },
+                                },
+                            ),
+                        )
 
-        # log validation metrics
-        wandb.log(
-            {
-                "predictions": table,
-                "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-                "val/accuracy": accuracy,
-                "val/bce": val_loss,
-                "val/dice": dice,
-                "val/mae": mae,
-            },
-            commit=False,
-        )
+                    # log validation metrics
+                    wandb.log(
+                        {
+                            "predictions": table,
+                            "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
+                            "val/accuracy": accuracy,
+                            "val/bce": val_loss,
+                            "val/dice": dice,
+                            "val/mae": mae,
+                        },
+                        commit=False,
+                    )
 
-        # update hyperparameters
-        net.train()
-        scheduler.step(dice)
+                    # update hyperparameters
+                    net.train()
+                    scheduler.step(dice)
 
-        # save weights when epoch end
-        torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
-        artifact = wandb.Artifact("pth", type="model")
-        artifact.add_file(f"checkpoints/model-{epoch}.pth")
-        wandb.run.log_artifact(artifact)
-
-        # export model to onnx format
-        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
-        torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
-        artifact = wandb.Artifact("onnx", type="model")
-        artifact.add_file(f"checkpoints/model-{epoch}.onnx")
-        wandb.run.log_artifact(artifact)
+                    # export model to onnx format when validation ends
+                    dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
+                    torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
+                    artifact = wandb.Artifact("onnx", type="model")
+                    artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
+                    wandb.run.log_artifact(artifact)
 
     # stop wandb
     wandb.run.finish()
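In short, the commit moves the evaluation round from the end of each epoch into the training loop behind a step counter, renames the validation progress bar to pbar2, exports an ONNX checkpoint at each validation instead of saving a .pth per epoch, and switches wandb.watch to a fixed log_freq of 100 with the criterion attached. Below is a minimal, self-contained sketch of the step-based trigger only; the helper names (run_validation, train_one_epoch) and the toy data are illustrative stand-ins, not code from src/train.py.

# Minimal sketch of the trigger added by this commit: run validation every
# 100 training steps and once more on the last batch of the epoch.
# run_validation / train_one_epoch and the toy tensors are illustrative,
# not taken from the repository.
import torch
from torch.utils.data import DataLoader, TensorDataset

VAL_EVERY = 100  # the commit hard-codes 100 in src/train.py


def run_validation(net: torch.nn.Module) -> None:
    """Stand-in for the evaluation round (val loss, dice, mae, wandb table)."""
    net.eval()
    with torch.inference_mode():
        pass  # metrics would be computed over val_loader here
    net.train()  # the commit also switches back to train mode after validating


def train_one_epoch(net: torch.nn.Module, train_loader: DataLoader) -> None:
    net.train()
    for step, (images, masks) in enumerate(train_loader, start=1):
        preds = net(images)                      # forward
        loss = (preds - masks).abs().mean()      # toy loss
        loss.backward()                          # backward (optimizer step omitted)

        # validate every VAL_EVERY steps, plus once at the end of the epoch
        if step % VAL_EVERY == 0 or step == len(train_loader):
            run_validation(net)


if __name__ == "__main__":
    ds = TensorDataset(torch.randn(8, 3, 32, 32), torch.rand(8, 1, 32, 32))
    train_one_epoch(torch.nn.Conv2d(3, 1, kernel_size=1), DataLoader(ds, batch_size=2))

With enumerate starting at 1, the second condition fires on the last batch, so each epoch in this sketch still ends with a validation pass.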