mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
feat: prediction script
Former-commit-id: dcaba9f9fbeaec393cea168a16287c690c5733b0 [formerly c69d87581930858ca293326002588a0188431fe7] Former-commit-id: 39fd4965182a2c4ae8ffac2f20c7dbaf4b82a61f
This commit is contained in:
parent
c73b803a15
commit
9fe76d8c61
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,8 +1,6 @@
|
||||||
.venv/
|
.venv/
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
checkpoints/
|
|
||||||
wandb/
|
wandb/
|
||||||
|
|
||||||
INTERRUPTED.pth
|
*.pth
|
||||||
|
|
22
.vscode/launch.json
vendored
Normal file
22
.vscode/launch.json
vendored
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python: Current File",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": [
|
||||||
|
"--input",
|
||||||
|
"SM.png",
|
||||||
|
"--output",
|
||||||
|
"test.png",
|
||||||
|
],
|
||||||
|
"justMyCode": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
1
SM.png.REMOVED.git-id
Normal file
1
SM.png.REMOVED.git-id
Normal file
|
@ -0,0 +1 @@
|
||||||
|
c6d08aa612451072cfe32a3ee086d08342ed9dd9
|
108
src/predict.py
108
src/predict.py
|
@ -1,16 +1,13 @@
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
|
import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
from src.utils.dataset import BasicDataset
|
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils.utils import plot_img_and_mask
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
@ -20,7 +17,7 @@ def get_args():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
"-m",
|
"-m",
|
||||||
default="MODEL.pth",
|
default="model.pth",
|
||||||
metavar="FILE",
|
metavar="FILE",
|
||||||
help="Specify the file in which the model is stored",
|
help="Specify the file in which the model is stored",
|
||||||
)
|
)
|
||||||
|
@ -28,7 +25,6 @@ def get_args():
|
||||||
"--input",
|
"--input",
|
||||||
"-i",
|
"-i",
|
||||||
metavar="INPUT",
|
metavar="INPUT",
|
||||||
nargs="+",
|
|
||||||
help="Filenames of input images",
|
help="Filenames of input images",
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
|
@ -36,108 +32,62 @@ def get_args():
|
||||||
"--output",
|
"--output",
|
||||||
"-o",
|
"-o",
|
||||||
metavar="OUTPUT",
|
metavar="OUTPUT",
|
||||||
nargs="+",
|
|
||||||
help="Filenames of output images",
|
help="Filenames of output images",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--viz",
|
"--threshold",
|
||||||
"-v",
|
|
||||||
action="store_true",
|
|
||||||
help="Visualize the images as they are processed",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-save",
|
|
||||||
"-n",
|
|
||||||
action="store_true",
|
|
||||||
help="Do not save the output masks",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mask-threshold",
|
|
||||||
"-t",
|
"-t",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.5,
|
default=0.5,
|
||||||
help="Minimum probability value to consider a mask pixel white",
|
help="Minimum probability value to consider a mask pixel white",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--scale",
|
|
||||||
"-s",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="Scale factor for the input images",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
|
def predict_img(net, img, device, threshold):
|
||||||
net.eval()
|
|
||||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
|
|
||||||
img = img.unsqueeze(0)
|
img = img.unsqueeze(0)
|
||||||
img = img.to(device=device, dtype=torch.float32)
|
img = img.to(device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
net.eval()
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
output = net(img)
|
output = net(img)
|
||||||
|
preds = torch.sigmoid(output)[0]
|
||||||
|
full_mask = preds.cpu().squeeze()
|
||||||
|
|
||||||
probs = torch.sigmoid(output)[0]
|
return np.asarray(full_mask > threshold)
|
||||||
|
|
||||||
tf = transforms.Compose(
|
|
||||||
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
|
|
||||||
)
|
|
||||||
|
|
||||||
full_mask = tf(probs.cpu()).squeeze()
|
|
||||||
|
|
||||||
if net.n_classes == 1:
|
|
||||||
return (full_mask > out_threshold).numpy()
|
|
||||||
else:
|
|
||||||
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
|
|
||||||
|
|
||||||
|
|
||||||
def get_output_filenames(args):
|
|
||||||
def _generate_name(fn):
|
|
||||||
split = os.path.splitext(fn)
|
|
||||||
return f"{split[0]}_OUT{split[1]}"
|
|
||||||
|
|
||||||
return args.output or list(map(_generate_name, args.input))
|
|
||||||
|
|
||||||
|
|
||||||
def mask_to_image(mask: np.ndarray):
|
|
||||||
if mask.ndim == 2:
|
|
||||||
return Image.fromarray((mask * 255).astype(np.uint8))
|
|
||||||
elif mask.ndim == 3:
|
|
||||||
return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args()
|
args = get_args()
|
||||||
in_files = args.input
|
|
||||||
out_files = get_output_filenames(args)
|
|
||||||
|
|
||||||
net = UNet(n_channels=3, n_classes=2)
|
net = UNet(n_channels=3, n_classes=1)
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
logging.info(f"Loading model {args.model}")
|
|
||||||
logging.info(f"Using device {device}")
|
logging.info(f"Using device {device}")
|
||||||
|
|
||||||
|
logging.info("Transfering model to device")
|
||||||
net.to(device=device)
|
net.to(device=device)
|
||||||
|
|
||||||
|
logging.info(f"Loading model {args.model}")
|
||||||
net.load_state_dict(torch.load(args.model, map_location=device))
|
net.load_state_dict(torch.load(args.model, map_location=device))
|
||||||
|
|
||||||
logging.info("Model loaded!")
|
logging.info(f"Loading image {args.input}")
|
||||||
|
img = Image.open(args.input).convert("RGB")
|
||||||
|
|
||||||
for i, filename in enumerate(in_files):
|
logging.info(f"Preprocessing image {args.input}")
|
||||||
|
tf = A.Compose(
|
||||||
|
[
|
||||||
|
A.ToFloat(max_value=255),
|
||||||
|
ToTensorV2(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
aug = tf(image=np.asarray(img))
|
||||||
|
img = aug["image"]
|
||||||
|
|
||||||
logging.info(f"\nPredicting image {filename} ...")
|
logging.info(f"Predicting image {args.input}")
|
||||||
img = Image.open(filename)
|
mask = predict_img(net=net, img=img, threshold=args.threshold, device=device)
|
||||||
|
|
||||||
mask = predict_img(
|
logging.info(f"Saving prediction to {args.output}")
|
||||||
net=net, full_img=img, scale_factor=args.scale, out_threshold=args.mask_threshold, device=device
|
mask = Image.fromarray(mask)
|
||||||
)
|
mask.write(args.output)
|
||||||
|
|
||||||
if not args.no_save:
|
|
||||||
out_filename = out_files[i]
|
|
||||||
result = mask_to_image(mask)
|
|
||||||
result.save(out_filename)
|
|
||||||
logging.info(f"Mask saved to {out_filename}")
|
|
||||||
|
|
||||||
if args.viz:
|
|
||||||
logging.info(f"Visualizing results for image {filename}, close to continue...")
|
|
||||||
plot_img_and_mask(img, mask)
|
|
||||||
|
|
16
src/train.py
16
src/train.py
|
@ -105,6 +105,9 @@ def main():
|
||||||
net.load_state_dict(torch.load(args.load, map_location=device))
|
net.load_state_dict(torch.load(args.load, map_location=device))
|
||||||
logging.info(f"Model loaded from {args.load}")
|
logging.info(f"Model loaded from {args.load}")
|
||||||
|
|
||||||
|
# save initial model.pth
|
||||||
|
torch.save(net.state_dict(), "model.pth")
|
||||||
|
|
||||||
# transfer network to device
|
# transfer network to device
|
||||||
net.to(device=device)
|
net.to(device=device)
|
||||||
|
|
||||||
|
@ -146,7 +149,7 @@ def main():
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
wandb.init(
|
run = wandb.init(
|
||||||
project="U-Net-tmp",
|
project="U-Net-tmp",
|
||||||
config=dict(
|
config=dict(
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
|
@ -156,6 +159,9 @@ def main():
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
wandb.watch(net, log_freq=100)
|
wandb.watch(net, log_freq=100)
|
||||||
|
# artifact = wandb.Artifact("model", type="model")
|
||||||
|
# artifact.add_file("model.pth")
|
||||||
|
# run.log_artifact(artifact)
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"""Starting training:
|
f"""Starting training:
|
||||||
|
@ -222,9 +228,11 @@ def main():
|
||||||
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
|
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
|
# torch.save(net.state_dict(), "model.pth")
|
||||||
torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
|
# run.log_artifact(artifact)
|
||||||
logging.info(f"Checkpoint {epoch} saved!")
|
logging.info(f"model saved!")
|
||||||
|
|
||||||
|
run.finish()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||||
|
|
Loading…
Reference in a new issue