diff --git a/src/unet/model.py b/src/unet/model.py index 93872f5..4556640 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -195,7 +195,7 @@ class UNet(pl.LightningModule): accuracy = (masks_true == masks_pred_bin).float().mean() rows = [] - if batch_idx % 50 == 0 or dice < 0.1: + if batch_idx % 50 == 0 or dice > 0.9: for i, (img, mask, pred, pred_bin) in enumerate( zip( images.cpu(),