PointMLP/classification_ModelNet40/test.py

124 lines
4.1 KiB
Python
Raw Normal View History

2023-08-03 14:40:14 +00:00
"""python test.py --model pointMLP --msg 20220209053148-404."""
2022-03-08 19:03:33 +00:00
import argparse
import datetime
2023-08-03 14:40:14 +00:00
import os
import models as models
import numpy as np
import sklearn.metrics as metrics
2022-03-08 19:03:33 +00:00
import torch
import torch.backends.cudnn as cudnn
2023-08-03 14:40:14 +00:00
import torch.nn.parallel
2022-03-08 19:03:33 +00:00
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from data import ModelNet40
from helper import cal_loss
2023-08-03 14:40:14 +00:00
from torch.utils.data import DataLoader
from utils import progress_bar
2022-03-08 19:03:33 +00:00
2023-08-03 14:40:14 +00:00
# Alphabetically sorted names of every callable attribute of the ``models``
# package; ``main`` instantiates the one named by ``--model`` through the same
# ``models.__dict__`` mapping (presumably model factory functions/classes).
model_names = sorted(
    attr_name for attr_name, attr in vars(models).items() if callable(attr)
)
2022-03-08 19:03:33 +00:00
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing an
            explicit list lets the parser be driven from tests or other code.

    Returns:
        argparse.Namespace with ``checkpoint``, ``msg``, ``batch_size``,
        ``model``, ``num_classes`` and ``num_points``.
    """
    # This script evaluates a saved model, so the help/usage text describes
    # loading a checkpoint and testing (the previous text described training).
    parser = argparse.ArgumentParser("testing")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        metavar="PATH",
        help="path of the checkpoint file to evaluate (no default)",
    )
    parser.add_argument("--msg", type=str, help="message after checkpoint")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size for testing")
    parser.add_argument("--model", default="pointMLP", help="model name [default: pointMLP]")
    parser.add_argument(
        "--num_classes",
        default=40,
        type=int,
        choices=[10, 40],
        help="number of classes: ModelNet10/40",
    )
    parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
    return parser.parse_args(argv)
2023-08-03 14:40:14 +00:00
2022-03-08 19:03:33 +00:00
def _resolve_checkpoint_path(args):
    """Return the checkpoint file given by --checkpoint, or derive it from --msg."""
    if args.checkpoint is not None:
        return args.checkpoint
    if args.msg is not None:
        # Run-directory layout checkpoints/<model>-<msg>/best_checkpoint.pth.
        # NOTE(review): assumed to match where the training script stores its
        # best weights — confirm.
        return os.path.join("checkpoints", f"{args.model}-{args.msg}", "best_checkpoint.pth")
    raise SystemExit("error: pass --checkpoint PATH or --msg RUN_ID to choose the weights to evaluate")


def _strip_data_parallel_prefix(state_dict):
    """Drop the ``module.`` key prefix that a DataParallel-wrapped model adds on save."""
    # Requires torch >= 1.9; also rewrites the state dict's version metadata.
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    prefix = "module."
    # Only strip when *every* key carries the prefix; a partially prefixed
    # dict was not produced by DataParallel and is left untouched.
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        consume_prefix_in_state_dict_if_present(state_dict, prefix)
    return state_dict


def main():
    """Evaluate a saved classifier on the ModelNet40 test split and print the metrics.

    The weights come from ``--checkpoint``; when it is omitted, ``--msg``
    selects ``checkpoints/<model>-<msg>/best_checkpoint.pth`` instead (the
    usage shown in the module docstring). The checkpoint must be a dict that
    holds the model state under ``"net"``.
    """
    args = parse_args()
    print(f"args: {args}")
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"==> Using device: {device}")
    checkpoint_path = _resolve_checkpoint_path(args)
    print(f"==> Using checkpoint: {checkpoint_path}")

    print("==> Preparing data..")
    test_loader = DataLoader(
        ModelNet40(partition="test", num_points=args.num_points),
        num_workers=4,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Model
    print("==> Building model..")
    net = models.__dict__[args.model]()
    criterion = cal_loss
    net = net.to(device)
    # map_location="cpu" lets a checkpoint saved from GPU tensors open on any machine.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # Load into the bare model *before* any DataParallel wrapping, with the
    # "module." prefix removed, so the same checkpoint loads on CPU and CUDA
    # whether or not it was saved from a DataParallel-wrapped model.
    net.load_state_dict(_strip_data_parallel_prefix(checkpoint["net"]))
    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    test_out = validate(net, test_loader, criterion, device)
    print(f"Vanilla out: {test_out}")
def validate(net, testloader, criterion, device):
    """Evaluate *net* on every batch of *testloader* and return summary metrics.

    Args:
        net: Model called as ``net(data)``; switched to eval mode and run under
            ``torch.no_grad()``. Its output is treated as per-class scores
            with classes along dim 1.
        testloader: Sized iterable of ``(data, label)`` batches (``len`` feeds
            the progress bar). ``data`` is permuted with ``permute(0, 2, 1)``
            before the forward pass — presumably (B, N, C) points to
            (B, C, N); confirm against the dataset.
        criterion: Callable ``criterion(logits, label)`` returning a scalar
            loss tensor.
        device: Target passed to ``.to(device)`` for each batch.

    Returns:
        dict with ``loss`` (mean per-batch loss), ``acc`` (overall accuracy in
        %), ``acc_avg`` (sklearn balanced accuracy in %) — each rounded to 3
        decimals — and ``time`` (elapsed whole seconds).

    Raises:
        ValueError: if *testloader* yields no batches.
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    test_true = []
    test_pred = []
    start = datetime.datetime.now()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(testloader):
            # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
            # one sample to a 0-d tensor, which breaks label.size(0) and the
            # np.concatenate below (e.g. with --batch_size 1).
            data, label = data.to(device), label.to(device).reshape(-1)
            data = data.permute(0, 2, 1)
            logits = net(data)
            loss = criterion(logits, label)
            test_loss += loss.item()
            preds = logits.max(dim=1)[1]  # index of the highest score per sample
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            num_batches = batch_idx + 1
            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (test_loss / num_batches, 100.0 * correct / total, correct, total),
            )
    if num_batches == 0:
        raise ValueError("testloader yielded no batches; nothing to evaluate")
    time_cost = int((datetime.datetime.now() - start).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / num_batches)),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }
2023-08-03 14:40:14 +00:00
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()