feat: even more wandb logging

Former-commit-id: 1a3c28040a734ca2229e33603405054abc8e3000 [formerly 907e4f7cae3c25a84baf0eaa5ec4d03ddaea0bdb]
Former-commit-id: fdfb7dcb7d0573efbff79956e7a4bebfe26e2171
This commit is contained in:
Laurent Fainsin 2022-07-01 10:27:12 +02:00
parent 7bdac6583b
commit 2ab95734e4

View file

@@ -40,6 +40,9 @@ def main():
IMG_SIZE=512, IMG_SIZE=512,
SPHERES=5, SPHERES=5,
), ),
settings=wandb.Settings(
code_dir="./src/",
),
) )
# create device # create device
@@ -51,7 +54,6 @@ def main():
# 0. Create network # 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
wandb.watch(net, log_freq=100) # TODO: 1/4 epochs
# transfer network to device # transfer network to device
net.to(device=device) net.to(device=device)
@@ -125,15 +127,11 @@ def main():
artifact.add_file("checkpoints/model-0.onnx") artifact.add_file("checkpoints/model-0.onnx")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# print the config # log gradients and weights four time per epoch
logging.info( wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4)
f"""wandb config:
{yaml.dump(wandb.config.as_dict())}
"""
)
# setup wandb table for saving images # print the config
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
try: try:
for epoch in range(1, wandb.config.EPOCHS + 1): for epoch in range(1, wandb.config.EPOCHS + 1):
@@ -165,6 +163,7 @@ def main():
# compute metrics # compute metrics
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
accuracy = (true_masks == pred_masks_bin).float().mean() accuracy = (true_masks == pred_masks_bin).float().mean()
dice = dice_coeff(pred_masks_bin, true_masks)
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
# update tqdm progress bar # update tqdm progress bar
@@ -174,9 +173,10 @@ def main():
# log metrics # log metrics
wandb.log( wandb.log(
{ {
"train/epoch": epoch - 1 + step / len(train_loader), "epoch": epoch - 1 + step / len(train_loader),
"train/accuracy": accuracy, "train/accuracy": accuracy,
"train/bce": train_loss, "train/bce": train_loss,
"train/dice": dice,
"train/mae": mae, "train/mae": mae,
} }
) )
@@ -184,6 +184,7 @@ def main():
# Evaluation round # Evaluation round
net.eval() net.eval()
accuracy = 0 accuracy = 0
val_loss = 0
dice = 0 dice = 0
mae = 0 mae = 0
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar: with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
@@ -198,19 +199,22 @@ def main():
masks_pred = net(images) masks_pred = net(images)
# compute metrics # compute metrics
val_loss += criterion(pred_masks, true_masks)
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
accuracy += (true_masks == pred_masks_bin).float().sum() accuracy += (true_masks == pred_masks_bin).float().mean()
dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False) dice += dice_coeff(masks_pred_bin, masks_true)
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks, reduction="sum")
# update progress bar # update progress bar
pbar.update(images.shape[0]) pbar.update(images.shape[0])
accuracy /= len(ds_valid) accuracy /= len(val_loader)
dice /= len(val_loader) # TODO: fix dice_coeff to not average val_loss /= len(val_loader)
mae /= len(ds_valid) dice /= len(val_loader)
mae /= len(val_loader)
# save the last validation batch to table # save the last validation batch to table
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
for i, (img, mask, pred) in enumerate( for i, (img, mask, pred) in enumerate(
zip( zip(
images.to("cpu"), images.to("cpu"),
@@ -223,11 +227,13 @@ def main():
# log validation metrics # log validation metrics
wandb.log( wandb.log(
{ {
"val/predictions": table, "predictions": table,
"val/accuracy": accuracy, "val/accuracy": accuracy,
"val/bce": val_loss,
"val/dice": dice, "val/dice": dice,
"val/mae": mae, "val/mae": mae,
} },
commit=False,
) )
# update hyperparameters # update hyperparameters