diff --git a/src/train.py b/src/train.py index 8d4dab1..de2fd48 100644 --- a/src/train.py +++ b/src/train.py @@ -40,6 +40,9 @@ def main(): IMG_SIZE=512, SPHERES=5, ), + settings=wandb.Settings( + code_dir="./src/", + ), ) # create device @@ -51,7 +54,6 @@ def main(): # 0. Create network net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad) - wandb.watch(net, log_freq=100) # TODO: 1/4 epochs # transfer network to device net.to(device=device) @@ -125,15 +127,11 @@ def main(): artifact.add_file("checkpoints/model-0.onnx") wandb.run.log_artifact(artifact) - # print the config - logging.info( - f"""wandb config: - {yaml.dump(wandb.config.as_dict())} - """ - ) + # log gradients and weights four time per epoch + wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4) - # setup wandb table for saving images - table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) + # print the config + logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") try: for epoch in range(1, wandb.config.EPOCHS + 1): @@ -165,6 +163,7 @@ def main(): # compute metrics pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() accuracy = (true_masks == pred_masks_bin).float().mean() + dice = dice_coeff(pred_masks_bin, true_masks) mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks) # update tqdm progress bar @@ -174,9 +173,10 @@ def main(): # log metrics wandb.log( { - "train/epoch": epoch - 1 + step / len(train_loader), + "epoch": epoch - 1 + step / len(train_loader), "train/accuracy": accuracy, "train/bce": train_loss, + "train/dice": dice, "train/mae": mae, } ) @@ -184,6 +184,7 @@ def main(): # Evaluation round net.eval() accuracy = 0 + val_loss = 0 dice = 0 mae = 0 with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar: @@ -198,19 +199,22 @@ def main(): masks_pred = net(images) # compute metrics + val_loss += criterion(pred_masks, true_masks) + mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks) masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - accuracy += (true_masks == pred_masks_bin).float().sum() - dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False) - mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks, reduction="sum") + accuracy += (true_masks == pred_masks_bin).float().mean() + dice += dice_coeff(masks_pred_bin, masks_true) # update progress bar pbar.update(images.shape[0]) - accuracy /= len(ds_valid) - dice /= len(val_loader) # TODO: fix dice_coeff to not average - mae /= len(ds_valid) + accuracy /= len(val_loader) + val_loss /= len(val_loader) + dice /= len(val_loader) + mae /= len(val_loader) # save the last validation batch to table + table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) for i, (img, mask, pred) in enumerate( zip( images.to("cpu"), @@ -223,11 +227,13 @@ def main(): # log validation metrics wandb.log( { - "val/predictions": table, + "predictions": table, "val/accuracy": accuracy, + "val/bce": val_loss, "val/dice": dice, "val/mae": mae, - } + }, + commit=False, ) # update hyperparameters