feat: validate every 100 steps
Former-commit-id: a3367de4ed56c5a708d66e7cd6be27d52bb92ccc [formerly 4606a91526eae57d56fc93df7ed34b867495e1c5] Former-commit-id: 6584449cd25b18ddd46f6804c8b1653e1c72dda0
This commit is contained in:
parent
cf8f52735a
commit
2571e5c6d3
19
src/train.py
19
src/train.py
|
@ -137,7 +137,7 @@ if __name__ == "__main__":
|
|||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# log gradients and weights four times per epoch
|
||||
wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4)
|
||||
wandb.watch(net, criterion, log_freq=100)
|
||||
|
||||
# print the config
|
||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||
|
@ -198,13 +198,14 @@ if __name__ == "__main__":
|
|||
}
|
||||
)
|
||||
|
||||
if step and (step % 100 == 0 or step == len(train_loader)):
|
||||
# Evaluation round
|
||||
net.eval()
|
||||
accuracy = 0
|
||||
val_loss = 0
|
||||
dice = 0
|
||||
mae = 0
|
||||
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
|
||||
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
|
||||
for images, masks_true in val_loader:
|
||||
|
||||
# transfer images to device
|
||||
|
@ -223,7 +224,7 @@ if __name__ == "__main__":
|
|||
dice += dice_coeff(masks_pred_bin, masks_true)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(images.shape[0])
|
||||
pbar2.update(images.shape[0])
|
||||
|
||||
accuracy /= len(val_loader)
|
||||
val_loss /= len(val_loader)
|
||||
|
@ -272,17 +273,11 @@ if __name__ == "__main__":
|
|||
net.train()
|
||||
scheduler.step(dice)
|
||||
|
||||
# save weights when the epoch ends
|
||||
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
|
||||
artifact = wandb.Artifact("pth", type="model")
|
||||
artifact.add_file(f"checkpoints/model-{epoch}.pth")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# export model to onnx format
|
||||
# export model to onnx format when validation ends
|
||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
|
||||
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
|
||||
artifact = wandb.Artifact("onnx", type="model")
|
||||
artifact.add_file(f"checkpoints/model-{epoch}.onnx")
|
||||
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# stop wandb
|
||||
|
|
Loading…
Reference in a new issue