feat: prediction script

Former-commit-id: dcaba9f9fbeaec393cea168a16287c690c5733b0 [formerly c69d87581930858ca293326002588a0188431fe7]
Former-commit-id: 39fd4965182a2c4ae8ffac2f20c7dbaf4b82a61f
Laurent Fainsin 2022-06-29 16:12:00 +02:00
parent c73b803a15
commit 9fe76d8c61
6 changed files with 65 additions and 86 deletions

.gitignore (4 changes)

@@ -1,8 +1,6 @@
 .venv/
 .mypy_cache/
 __pycache__/
-checkpoints/
 wandb/
-INTERRUPTED.pth
+*.pth

.vscode/launch.json (new file, 22 lines)

@@ -0,0 +1,22 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "args": [
+                "--input",
+                "SM.png",
+                "--output",
+                "test.png",
+            ],
+            "justMyCode": true
+        }
+    ]
+}
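For reference, this debug configuration is equivalent to invoking the prediction script directly with the same arguments; a minimal sketch, where `predict.py` is an assumed script name since the diff does not show the file's path:

```python
# Equivalent of the launch.json debug run, invoked from Python.
# "predict.py" is an assumed script name; the diff does not show the path.
import subprocess

subprocess.run(
    ["python", "predict.py", "--input", "SM.png", "--output", "test.png"],
    check=True,  # raise CalledProcessError on a non-zero exit
)
```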

SM.png.REMOVED.git-id (new file, 1 line)

@@ -0,0 +1 @@
+c6d08aa612451072cfe32a3ee086d08342ed9dd9

(prediction script)

@@ -1,16 +1,13 @@
 import argparse
 import logging
-import os
 
+import albumentations as A
 import numpy as np
 import torch
-import torch.nn.functional as F
+from albumentations.pytorch import ToTensorV2
 from PIL import Image
-from torchvision import transforms
 
-from src.utils.dataset import BasicDataset
 from unet import UNet
-from utils.utils import plot_img_and_mask
 
 
 def get_args():
@@ -20,7 +17,7 @@ def get_args():
     parser.add_argument(
         "--model",
         "-m",
-        default="MODEL.pth",
+        default="model.pth",
         metavar="FILE",
         help="Specify the file in which the model is stored",
     )
@@ -28,7 +25,6 @@ def get_args():
         "--input",
         "-i",
         metavar="INPUT",
-        nargs="+",
         help="Filenames of input images",
         required=True,
     )
@@ -36,108 +32,62 @@ def get_args():
         "--output",
         "-o",
         metavar="OUTPUT",
-        nargs="+",
         help="Filenames of output images",
     )
     parser.add_argument(
-        "--viz",
-        "-v",
-        action="store_true",
-        help="Visualize the images as they are processed",
-    )
-    parser.add_argument(
-        "--no-save",
-        "-n",
-        action="store_true",
-        help="Do not save the output masks",
-    )
-    parser.add_argument(
-        "--mask-threshold",
+        "--threshold",
         "-t",
         type=float,
         default=0.5,
         help="Minimum probability value to consider a mask pixel white",
     )
-    parser.add_argument(
-        "--scale",
-        "-s",
-        type=float,
-        default=0.5,
-        help="Scale factor for the input images",
-    )
 
     return parser.parse_args()
 
 
-def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
-    net.eval()
-    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
+def predict_img(net, img, device, threshold):
     img = img.unsqueeze(0)
     img = img.to(device=device, dtype=torch.float32)
 
+    net.eval()
     with torch.inference_mode():
         output = net(img)
+        preds = torch.sigmoid(output)[0]
+        full_mask = preds.cpu().squeeze()
 
-        probs = torch.sigmoid(output)[0]
-
-        tf = transforms.Compose(
-            [transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
-        )
-        full_mask = tf(probs.cpu()).squeeze()
-
-    if net.n_classes == 1:
-        return (full_mask > out_threshold).numpy()
-    else:
-        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
-
-
-def get_output_filenames(args):
-    def _generate_name(fn):
-        split = os.path.splitext(fn)
-        return f"{split[0]}_OUT{split[1]}"
-
-    return args.output or list(map(_generate_name, args.input))
-
-
-def mask_to_image(mask: np.ndarray):
-    if mask.ndim == 2:
-        return Image.fromarray((mask * 255).astype(np.uint8))
-    elif mask.ndim == 3:
-        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
+    return np.asarray(full_mask > threshold)
 
 
 if __name__ == "__main__":
     args = get_args()
-    in_files = args.input
-    out_files = get_output_filenames(args)
 
-    net = UNet(n_channels=3, n_classes=2)
+    net = UNet(n_channels=3, n_classes=1)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logging.info(f"Loading model {args.model}")
     logging.info(f"Using device {device}")
 
+    logging.info("Transfering model to device")
     net.to(device=device)
+
+    logging.info(f"Loading model {args.model}")
     net.load_state_dict(torch.load(args.model, map_location=device))
-    logging.info("Model loaded!")
 
-    for i, filename in enumerate(in_files):
+    logging.info(f"Loading image {args.input}")
+    img = Image.open(args.input).convert("RGB")
 
-        logging.info(f"\nPredicting image {filename} ...")
-        img = Image.open(filename)
+    logging.info(f"Preprocessing image {args.input}")
+    tf = A.Compose(
+        [
+            A.ToFloat(max_value=255),
+            ToTensorV2(),
+        ],
+    )
+    aug = tf(image=np.asarray(img))
+    img = aug["image"]
 
-        mask = predict_img(
-            net=net, full_img=img, scale_factor=args.scale, out_threshold=args.mask_threshold, device=device
-        )
+    logging.info(f"Predicting image {args.input}")
+    mask = predict_img(net=net, img=img, threshold=args.threshold, device=device)
 
-        if not args.no_save:
-            out_filename = out_files[i]
-            result = mask_to_image(mask)
-            result.save(out_filename)
-            logging.info(f"Mask saved to {out_filename}")
-
-        if args.viz:
-            logging.info(f"Visualizing results for image {filename}, close to continue...")
-            plot_img_and_mask(img, mask)
+    logging.info(f"Saving prediction to {args.output}")
+    mask = Image.fromarray(mask)
+    mask.write(args.output)
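Taken together, the rewritten script loads a single image, normalizes it with albumentations, thresholds the sigmoid output, and writes the mask. A minimal self-contained sketch of that flow follows, assuming the repo's `UNet` and the filenames from the launch configuration above; note that `PIL.Image.Image` provides `save()` rather than `write()`, and `Image.fromarray` expects an integer array rather than a boolean one, so the sketch adjusts the last two committed lines accordingly:

```python
# A minimal sketch of the committed prediction flow, assuming the repo's
# UNet and the filenames from launch.json ("model.pth", "SM.png", "test.png").
import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

from unet import UNet  # assumption: importable as in the diff

net = UNet(n_channels=3, n_classes=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device=device)
net.load_state_dict(torch.load("model.pth", map_location=device))
net.eval()

# Load the image and normalize it to a float32 CHW tensor in [0, 1].
img = Image.open("SM.png").convert("RGB")
tf = A.Compose([A.ToFloat(max_value=255), ToTensorV2()])
img = tf(image=np.asarray(img))["image"]

with torch.inference_mode():
    output = net(img.unsqueeze(0).to(device=device, dtype=torch.float32))
    mask = torch.sigmoid(output)[0].cpu().squeeze() > 0.5  # boolean HxW mask

# PIL images have save(), not write(); fromarray() also needs uint8, not bool.
Image.fromarray(mask.numpy().astype(np.uint8) * 255).save("test.png")
```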

(training script)

@@ -105,6 +105,9 @@ def main():
         net.load_state_dict(torch.load(args.load, map_location=device))
         logging.info(f"Model loaded from {args.load}")
 
+    # save initial model.pth
+    torch.save(net.state_dict(), "model.pth")
+
     # transfer network to device
     net.to(device=device)
@@ -146,7 +149,7 @@ def main():
     criterion = nn.BCEWithLogitsLoss()
 
     # setup wandb
-    wandb.init(
+    run = wandb.init(
         project="U-Net-tmp",
         config=dict(
             epochs=args.epochs,
@@ -156,6 +159,9 @@ def main():
         ),
     )
     wandb.watch(net, log_freq=100)
+    # artifact = wandb.Artifact("model", type="model")
+    # artifact.add_file("model.pth")
+    # run.log_artifact(artifact)
 
     logging.info(
         f"""Starting training:
@@ -222,9 +228,11 @@ def main():
             print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
 
             # save weights when epoch end
-            Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
-            torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
-            logging.info(f"Checkpoint {epoch} saved!")
+            # torch.save(net.state_dict(), "model.pth")
+            # run.log_artifact(artifact)
+            logging.info(f"model saved!")
+
+        run.finish()
 
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
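The commented-out `wandb.Artifact` lines sketch where checkpointing is headed; a minimal working sketch of that flow against the public wandb API, with a stand-in module in place of the script's `UNet`:

```python
# Sketch of the artifact logging that the commented-out lines point toward,
# using the public wandb API; torch.nn.Linear stands in for the script's UNet.
import torch
import wandb

net = torch.nn.Linear(4, 1)  # stand-in module

run = wandb.init(project="U-Net-tmp")

# At each epoch end: overwrite the weights file, then log it as an artifact.
torch.save(net.state_dict(), "model.pth")
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model.pth")
run.log_artifact(artifact)  # repeated calls create new versions: v0, v1, ...

run.finish()
```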

test.png (new binary file, 1.5 KiB; contents not shown)