diff --git a/src/evaluate.py b/src/evaluate.py
index cd5db85..1d0995c 100644
--- a/src/evaluate.py
+++ b/src/evaluate.py
@@ -20,18 +20,18 @@ def evaluate(net, dataloader, device):
         # forward, predict the mask
         with torch.inference_mode():
             masks_pred = net(images)
-            masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
+            masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
 
         # compute the Dice score
-        dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
+        dice_score += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
 
         # update progress bar
         pbar.update(images.shape[0])
 
     # save some images to wandb
-    table = wandb.Table(columns=["image", "mask", "prediction"])
-    for img, mask, pred in zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu")):
-        table.add_data(wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
+    table = wandb.Table(columns=["id", "image", "mask", "prediction"])
+    for i, (img, mask, pred) in enumerate(zip(images.to("cpu"), masks_true.to("cpu"), masks_pred.to("cpu"))):
+        table.add_data(i, wandb.Image(img), wandb.Image(mask), wandb.Image(pred))
     wandb.log({"predictions_table": table}, commit=False)
 
     net.train()
diff --git a/src/predict.py b/src/predict.py
index 6d569b4..c1a445c 100755
--- a/src/predict.py
+++ b/src/predict.py
@@ -13,32 +13,6 @@ from unet import UNet
 from utils.utils import plot_img_and_mask
 
 
-def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
-    net.eval()
-    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
-    img = img.unsqueeze(0)
-    img = img.to(device=device, dtype=torch.float32)
-
-    with torch.no_grad():
-        output = net(img)
-
-    if net.n_classes > 1:
-        probs = F.softmax(output, dim=1)[0]
-    else:
-        probs = torch.sigmoid(output)[0]
-
-    tf = transforms.Compose(
-        [transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
-    )
-
-    full_mask = tf(probs.cpu()).squeeze()
-
-    if net.n_classes == 1:
-        return (full_mask > out_threshold).numpy()
-    else:
-        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
-
-
 def get_args():
     parser = argparse.ArgumentParser(
         description="Predict masks from input images",
@@ -95,6 +69,29 @@ def get_args():
     return parser.parse_args()
 
 
+def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
+    net.eval()
+    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
+    img = img.unsqueeze(0)
+    img = img.to(device=device, dtype=torch.float32)
+
+    with torch.inference_mode():
+        output = net(img)
+
+    probs = torch.sigmoid(output)[0]
+
+    tf = transforms.Compose(
+        [transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
+    )
+
+    full_mask = tf(probs.cpu()).squeeze()
+
+    if net.n_classes == 1:
+        return (full_mask > out_threshold).numpy()
+    else:
+        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
+
+
 def get_output_filenames(args):
     def _generate_name(fn):
         split = os.path.splitext(fn)
@@ -127,6 +124,7 @@ if __name__ == "__main__":
     logging.info("Model loaded!")
 
     for i, filename in enumerate(in_files):
+        logging.info(f"\nPredicting image {filename} ...")
         img = Image.open(filename)
 
diff --git a/src/train.py b/src/train.py
index 809244f..1fd7468 100644
--- a/src/train.py
+++ b/src/train.py
@@ -111,10 +111,11 @@ def main():
     # 1. Create transforms
     tf_train = A.Compose(
         [
-            A.Resize(500, 500),
+            A.Resize(512, 512),
             A.Flip(),
             A.ColorJitter(),
             RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
+            A.GaussianBlur(),
             A.ISONoise(),
             A.ToFloat(max_value=255),
             ToTensorV2(),
@@ -122,7 +123,7 @@
     )
     tf_valid = A.Compose(
         [
-            A.Resize(500, 500),
+            A.Resize(512, 512),
             RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
             A.ToFloat(max_value=255),
             ToTensorV2(),
@@ -154,7 +155,7 @@
             amp=args.amp,
         ),
     )
-    wandb.save(f"{CHECKPOINT_DIR}/*")
+    wandb.watch(net, log_freq=100)
 
     logging.info(
         f"""Starting training: