From 9fe76d8c6102ac8dc184cad3d36a49f2612c4464 Mon Sep 17 00:00:00 2001
From: Laurent Fainsin
Date: Wed, 29 Jun 2022 16:12:00 +0200
Subject: [PATCH] feat: prediction script

Former-commit-id: dcaba9f9fbeaec393cea168a16287c690c5733b0 [formerly c69d87581930858ca293326002588a0188431fe7]
Former-commit-id: 39fd4965182a2c4ae8ffac2f20c7dbaf4b82a61f
---
 .gitignore            |   4 +-
 .vscode/launch.json   |  22 +++++++++
 SM.png.REMOVED.git-id |   1 +
 src/predict.py        | 108 ++++++++++++------------------------------
 src/train.py          |  16 +++++--
 test.png              | Bin 0 -> 1575 bytes
 6 files changed, 65 insertions(+), 86 deletions(-)
 create mode 100644 .vscode/launch.json
 create mode 100644 SM.png.REMOVED.git-id
 create mode 100644 test.png

diff --git a/.gitignore b/.gitignore
index c7bb7ca..9d60c97 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,6 @@
 .venv/
 .mypy_cache/
 __pycache__/
-
-checkpoints/
 wandb/
 
-INTERRUPTED.pth
+*.pth

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..f7a9165
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,22 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "args": [
+                "--input",
+                "SM.png",
+                "--output",
+                "test.png",
+            ],
+            "justMyCode": true
+        }
+    ]
+}

diff --git a/SM.png.REMOVED.git-id b/SM.png.REMOVED.git-id
new file mode 100644
index 0000000..21a947e
--- /dev/null
+++ b/SM.png.REMOVED.git-id
@@ -0,0 +1 @@
+c6d08aa612451072cfe32a3ee086d08342ed9dd9
\ No newline at end of file

diff --git a/src/predict.py b/src/predict.py
index c1a445c..587e11e 100755
--- a/src/predict.py
+++ b/src/predict.py
@@ -1,16 +1,13 @@
 import argparse
 import logging
-import os
 
+import albumentations as A
 import numpy as np
 import torch
-import torch.nn.functional as F
+from albumentations.pytorch import ToTensorV2
 from PIL import Image
-from torchvision import transforms
 
-from src.utils.dataset import BasicDataset
 from unet import UNet
-from utils.utils import plot_img_and_mask
 
 
 def get_args():
@@ -20,7 +17,7 @@
     parser.add_argument(
         "--model",
         "-m",
-        default="MODEL.pth",
+        default="model.pth",
         metavar="FILE",
         help="Specify the file in which the model is stored",
     )
@@ -28,7 +25,6 @@
         "--input",
         "-i",
         metavar="INPUT",
-        nargs="+",
         help="Filenames of input images",
         required=True,
     )
@@ -36,108 +32,62 @@
         "--output",
         "-o",
         metavar="OUTPUT",
-        nargs="+",
         help="Filenames of output images",
     )
     parser.add_argument(
-        "--viz",
-        "-v",
-        action="store_true",
-        help="Visualize the images as they are processed",
-    )
-    parser.add_argument(
-        "--no-save",
-        "-n",
-        action="store_true",
-        help="Do not save the output masks",
-    )
-    parser.add_argument(
-        "--mask-threshold",
+        "--threshold",
         "-t",
         type=float,
         default=0.5,
         help="Minimum probability value to consider a mask pixel white",
     )
-    parser.add_argument(
-        "--scale",
-        "-s",
-        type=float,
-        default=0.5,
-        help="Scale factor for the input images",
-    )
 
     return parser.parse_args()
 
 
-def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
-    net.eval()
-    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
+def predict_img(net, img, device, threshold):
     img = img.unsqueeze(0)
     img = img.to(device=device, dtype=torch.float32)
 
+    net.eval()
     with torch.inference_mode():
         output = net(img)
+        preds = torch.sigmoid(output)[0]
+        full_mask = preds.cpu().squeeze()
 
-        probs = torch.sigmoid(output)[0]
-
-        tf = transforms.Compose(
-            [transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
-        )
-
-        full_mask = tf(probs.cpu()).squeeze()
-
-    if net.n_classes == 1:
-        return (full_mask > out_threshold).numpy()
-    else:
-        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
-
-
-def get_output_filenames(args):
-    def _generate_name(fn):
-        split = os.path.splitext(fn)
-        return f"{split[0]}_OUT{split[1]}"
-
-    return args.output or list(map(_generate_name, args.input))
-
-
-def mask_to_image(mask: np.ndarray):
-    if mask.ndim == 2:
-        return Image.fromarray((mask * 255).astype(np.uint8))
-    elif mask.ndim == 3:
-        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
+    return np.asarray(full_mask > threshold)
 
 
 if __name__ == "__main__":
     args = get_args()
-    in_files = args.input
-    out_files = get_output_filenames(args)
 
-    net = UNet(n_channels=3, n_classes=2)
+    net = UNet(n_channels=3, n_classes=1)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logging.info(f"Loading model {args.model}")
     logging.info(f"Using device {device}")
 
+    logging.info("Transferring model to device")
     net.to(device=device)
+
+    logging.info(f"Loading model {args.model}")
     net.load_state_dict(torch.load(args.model, map_location=device))
 
-    logging.info("Model loaded!")
+    logging.info(f"Loading image {args.input}")
+    img = Image.open(args.input).convert("RGB")
 
-    for i, filename in enumerate(in_files):
+    logging.info(f"Preprocessing image {args.input}")
+    tf = A.Compose(
+        [
+            A.ToFloat(max_value=255),
+            ToTensorV2(),
+        ],
+    )
+    aug = tf(image=np.asarray(img))
+    img = aug["image"]
 
-        logging.info(f"\nPredicting image {filename} ...")
-        img = Image.open(filename)
+    logging.info(f"Predicting image {args.input}")
+    mask = predict_img(net=net, img=img, threshold=args.threshold, device=device)
 
-        mask = predict_img(
-            net=net, full_img=img, scale_factor=args.scale, out_threshold=args.mask_threshold, device=device
-        )
-
-        if not args.no_save:
-            out_filename = out_files[i]
-            result = mask_to_image(mask)
-            result.save(out_filename)
-            logging.info(f"Mask saved to {out_filename}")
-
-        if args.viz:
-            logging.info(f"Visualizing results for image {filename}, close to continue...")
-            plot_img_and_mask(img, mask)
+    logging.info(f"Saving prediction to {args.output}")
+    mask = Image.fromarray(mask)
+    mask.save(args.output)
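The predict.py rewrite above swaps the old torchvision/BasicDataset preprocessing for an albumentations pipeline. A minimal sketch of what that pipeline hands to the network, assuming only the albumentations and torch packages the script already imports (the 64x64 dummy image and the assertions are illustrative, not part of the patch):

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2

# Same pipeline as the new predict.py: ToFloat rescales uint8 [0, 255] to
# float32 [0, 1], ToTensorV2 converts the HWC numpy array to a CHW tensor.
tf = A.Compose([A.ToFloat(max_value=255), ToTensorV2()])

dummy = np.full((64, 64, 3), 128, dtype=np.uint8)  # stand-in RGB image
out = tf(image=dummy)["image"]

assert out.shape == (3, 64, 64)  # channels-first, as UNet(n_channels=3) expects
assert out.dtype == torch.float32
assert abs(out.max().item() - 128 / 255) < 1e-6  # values rescaled into [0, 1]

One caveat: the script never calls logging.basicConfig, so its logging.info messages stay silent at Python's default WARNING level.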
diff --git a/src/train.py b/src/train.py
index 1fd7468..4f56bde 100644
--- a/src/train.py
+++ b/src/train.py
@@ -105,6 +105,9 @@ def main():
         net.load_state_dict(torch.load(args.load, map_location=device))
         logging.info(f"Model loaded from {args.load}")
 
+    # save initial model.pth
+    torch.save(net.state_dict(), "model.pth")
+
     # transfer network to device
     net.to(device=device)
 
@@ -146,7 +149,7 @@ def main():
     criterion = nn.BCEWithLogitsLoss()
 
     # setup wandb
-    wandb.init(
+    run = wandb.init(
         project="U-Net-tmp",
         config=dict(
             epochs=args.epochs,
@@ -156,6 +159,9 @@ def main():
         ),
     )
     wandb.watch(net, log_freq=100)
+    # artifact = wandb.Artifact("model", type="model")
+    # artifact.add_file("model.pth")
+    # run.log_artifact(artifact)
 
     logging.info(
         f"""Starting training:
@@ -222,9 +228,11 @@ def main():
             print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
 
             # save weights when epoch end
-            Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
-            torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
-            logging.info(f"Checkpoint {epoch} saved!")
+            # torch.save(net.state_dict(), "model.pth")
+            # run.log_artifact(artifact)
+            logging.info("model saved!")
+
+        run.finish()
 
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
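The wandb.Artifact calls stay commented out in this commit, both at setup and in the epoch loop. If they were enabled, the usual pattern looks like the sketch below (an assumption based on the stock wandb API, not code from this repo; note that log_artifact finalizes an artifact, so each upload needs a fresh wandb.Artifact rather than reusing one built once next to wandb.init):

import torch
import wandb

run = wandb.init(project="U-Net-tmp")
net = torch.nn.Linear(1, 1)  # stand-in for the repo's UNet(n_channels=3, n_classes=1)

# ... train for an epoch, then snapshot the weights ...
torch.save(net.state_dict(), "model.pth")

artifact = wandb.Artifact("model", type="model")  # a fresh Artifact per version
artifact.add_file("model.pth")
run.log_artifact(artifact)  # uploaded and versioned as model:v0, model:v1, ...

run.finish()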
diff --git a/test.png b/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0cad88ce1832fdd02c3e4968e3e838520b3af20
GIT binary patch
literal 1575
zcmZ9Mdr*=I6ve5ru{9H$+H5mvEz3u1&&?Vi5oHmUFU%eqkk*c&NoBxiAzC>3Sgl+%
zwG{JhrC^$-Kwws5rHPdW3Vdp1k`Ny%3Lzp}?e5g<{p&aPk281fIdjs30?_lVmRiAJ
zu=$uhzE~J+c0CMcftdAW0kgIqV!&W?AdD|+e|(`DM3UsYm!yp4^{fuF*!^JS$P(L(
zWe@#r?NIxMoIH2<*`an+J)aeKahW4}Z~@t3vFnjJ_S-Mn)=;t5&p)))&F|AV6H65b
z<4n9j0I^~YJ++y6r^q7HmGsQCy
z*NN?PyPlfugsVMXbgF15T@Fqm;mCeBIE&90?NmXe@t4i~d16v_>SexcIKo}~K5`nQ
z3fzvaXu=9iKxoKO7{*M
hMf^f0Pwk63hrGR6{{Xr+?yA*(db%{({Dbug_>KbUmys;`W@7m>)@xQDaXGV_TzK@
z;pWh&*dV~Z+b6xZX!{M_h0%=ON)$8ETA5g2X%xBytBk+HgN>qIpY~+BF!<`cqeOYn
zt3)%-lClK_o(0xdw>_<+2}5d0oo{c;;%Sn+_pO8KfJZu4`xNWh^amj}QtMdc
zZ`p4+V3Ml~!9c}IjR~l`_hG%-Mt$1jroT9&aVEYqVSiVY%i<2hO!Oy_5|t-lixIrb
zy31_ZRwKJ&qDR8w>g9MU!4RiCe243`Cer-88B$uZc98=U#pjm44BbqkMiT}OE|93Q
zCU8$CJr@CZllLwAIQn173|@6HhrydMh1gjsr}Q+O%LottMstb1b01*ev59R)>a*shzl%PBctAOw)5(Kf3D15Ad
zkFFia%&oo3!d9g{mA$?G8WdzXHKHW}am_j9ocb<3FJx8YjACHG>9!sqC>)2(1Lr$6
z2a#te)j1O@jtWOUdnOa4k61kjj_E@oU@TQmw0!fE=O-#%h8>YvY
zA{ZAWiFWA0Iq=v`-W0F1jhR~IGf3Agi&t$>-6fD)Hf+|$T;%dM{DGdK%;h5Pm#d0N
zi9i&@85qmnvT2W_mQE)l!y7py;oFqyyap}(^lugF2%+=+nlRjpb#vP~h=;v}Ol9I!
zfyVk>?PN*ri>lQb#=~|!jbUV2G3QOG%!0CI0VpzLR+bsU2yS_2@jLYHIF)lm6-kq+
zJRoj8-zvE6wY4k*q9!hsshrf(G>*Dq>G$BZuTL5-%j#O#+(I4O-IPU1!?v8UorJ?&
z0szUvW@jnC=%LT6%wxosb%-fVNhdCFFvSD#xk!ujj7a);H6u+^yY^bWO>imx#2fR{
zvZ>LyB&26pZvx*thiZ0b$4G|Xw$|<>-ta!npVK$=jZ1a!(TgdwmsKX?+knJK;WW5=7-8evhCouPZ6~gQa@a60{%J>`8b3oJp

literal 0
HcmV?d00001
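With the patch applied, the script runs the way .vscode/launch.json wires it up: python src/predict.py --input SM.png --output test.png, with --model defaulting to model.pth. A hypothetical smoke test of the saved mask (file name taken from launch.json; assumes the mask was written from a bool array as in predict.py):

import numpy as np
from PIL import Image

mask = np.asarray(Image.open("test.png"))  # a 1-bit PNG loads back as a bool array
print(mask.shape, mask.dtype)  # (H, W) bool
print(mask.mean())  # fraction of pixels predicted white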