diff --git a/src/train.py b/src/train.py index b0e7c84..d8aff3e 100644 --- a/src/train.py +++ b/src/train.py @@ -137,7 +137,7 @@ if __name__ == "__main__": wandb.run.log_artifact(artifact) # log gradients and weights four time per epoch - wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4) + wandb.watch(net, criterion, log_freq=100) # print the config logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") @@ -198,92 +198,87 @@ if __name__ == "__main__": } ) - # Evaluation round - net.eval() - accuracy = 0 - val_loss = 0 - dice = 0 - mae = 0 - with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar: - for images, masks_true in val_loader: + if step and (step % 100 == 0 or step == len(train_loader)): + # Evaluation round + net.eval() + accuracy = 0 + val_loss = 0 + dice = 0 + mae = 0 + with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2: + for images, masks_true in val_loader: - # transfer images to device - images = images.to(device=device) - masks_true = masks_true.unsqueeze(1).to(device=device) + # transfer images to device + images = images.to(device=device) + masks_true = masks_true.unsqueeze(1).to(device=device) - # forward - with torch.inference_mode(): - masks_pred = net(images) + # forward + with torch.inference_mode(): + masks_pred = net(images) - # compute metrics - val_loss += criterion(pred_masks, true_masks) - mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks) - masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() - accuracy += (true_masks == pred_masks_bin).float().mean() - dice += dice_coeff(masks_pred_bin, masks_true) + # compute metrics + val_loss += criterion(pred_masks, true_masks) + mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks) + masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float() + accuracy += (true_masks == pred_masks_bin).float().mean() + dice += dice_coeff(masks_pred_bin, masks_true) - # update progress bar - pbar.update(images.shape[0]) + # update progress bar + pbar2.update(images.shape[0]) - accuracy /= len(val_loader) - val_loss /= len(val_loader) - dice /= len(val_loader) - mae /= len(val_loader) + accuracy /= len(val_loader) + val_loss /= len(val_loader) + dice /= len(val_loader) + mae /= len(val_loader) - # save the last validation batch to table - table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) - for i, (img, mask, pred, pred_bin) in enumerate( - zip( - images.to("cpu"), - masks_true.to("cpu"), - masks_pred.to("cpu"), - masks_pred_bin.to("cpu").squeeze().int().numpy(), - ) - ): - table.add_data( - i, - wandb.Image(img), - wandb.Image(mask), - wandb.Image( - pred, - masks={ - "predictions": { - "mask_data": pred_bin, - "class_labels": class_labels, - }, + # save the last validation batch to table + table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"]) + for i, (img, mask, pred, pred_bin) in enumerate( + zip( + images.to("cpu"), + masks_true.to("cpu"), + masks_pred.to("cpu"), + masks_pred_bin.to("cpu").squeeze().int().numpy(), + ) + ): + table.add_data( + i, + wandb.Image(img), + wandb.Image(mask), + wandb.Image( + pred, + masks={ + "predictions": { + "mask_data": pred_bin, + "class_labels": class_labels, + }, + }, + ), + ) + + # log validation metrics + wandb.log( + { + "predictions": table, + "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], + "val/accuracy": accuracy, + "val/bce": val_loss, + "val/dice": dice, + "val/mae": mae, }, - ), - ) + commit=False, + ) - # log validation metrics - wandb.log( - { - "predictions": table, - "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], - "val/accuracy": accuracy, - "val/bce": val_loss, - "val/dice": dice, - "val/mae": mae, - }, - commit=False, - ) + # update hyperparameters + net.train() + scheduler.step(dice) - # update hyperparameters - net.train() - scheduler.step(dice) - - # save weights when epoch end - torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth") - artifact = wandb.Artifact("pth", type="model") - artifact.add_file(f"checkpoints/model-{epoch}.pth") - wandb.run.log_artifact(artifact) - - # export model to onnx format - dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) - torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx") - artifact = wandb.Artifact("onnx", type="model") - artifact.add_file(f"checkpoints/model-{epoch}.onnx") - wandb.run.log_artifact(artifact) + # export model to onnx format when validation ends + dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) + torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx") + artifact = wandb.Artifact("onnx", type="model") + artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx") + wandb.run.log_artifact(artifact) # stop wandb wandb.run.finish()