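"""Evaluate a trained point-cloud classifier on the ModelNet40 test split.

Usage:
    python test.py --model pointMLP --checkpoint PATH/TO/best_checkpoint.pth

The ``--msg`` option (e.g. ``--msg 20220209053148-404``) is still accepted but
is not read by this script; the checkpoint file must be given explicitly.
"""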
|
|
import argparse
|
|
import datetime
|
|
import os
|
|
|
|
import models as models
|
|
import numpy as np
|
|
import sklearn.metrics as metrics
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
import torch.nn.parallel
|
|
import torch.optim
|
|
import torch.utils.data
|
|
import torch.utils.data.distributed
|
|
from data import ModelNet40
|
|
from helper import cal_loss
|
|
from torch.utils.data import DataLoader
|
|
from utils import progress_bar
|
|
|
|
# Sorted names of every callable attribute exposed by the ``models`` module.
model_names = sorted(name for name, obj in vars(models).items() if callable(obj))
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``checkpoint``, ``msg``,
        ``batch_size``, ``model``, ``num_classes`` and ``num_points``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a trained model on the ModelNet40 test split")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        metavar="PATH",
        help="path to the checkpoint file to evaluate",
    )
    # Kept for command-line compatibility; nothing in this script reads it.
    parser.add_argument("--msg", type=str, help="message suffix of the checkpoint name (currently unused)")
    parser.add_argument("--batch_size", type=int, default=16, help="batch size used for evaluation [default: %(default)s]")
    # %(default)s keeps the help text in sync with the real default value.
    parser.add_argument("--model", default="pointMLP", help="model name [default: %(default)s]")
    parser.add_argument(
        "--num_classes",
        default=40,
        type=int,
        choices=[10, 40],
        help="number of classes, ModelNet10/40 [default: %(default)s]",
    )
    parser.add_argument("--num_points", type=int, default=1024, help="number of points per sample [default: %(default)s]")
    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Evaluate a checkpointed model on the ModelNet40 test split and print the result.

    Reads the command-line options via ``parse_args()``, builds the model named
    by ``--model``, loads the weights stored under the ``"net"`` key of the
    checkpoint file, runs ``validate`` on the test loader and prints the
    returned metrics.

    Raises:
        SystemExit: if no checkpoint path was given on the command line.
    """
    # Local import: strips the "module." key prefix that DataParallel adds.
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    args = parse_args()
    print(f"args: {args}")
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    # Fail early with a readable message instead of letting torch.load(None)
    # blow up after the dataset and model have already been built.
    if args.checkpoint is None:
        raise SystemExit("error: no checkpoint given; pass one with -c/--checkpoint PATH")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"==> Using device: {device}")
    print(f"==> Using checkpoint: {args.checkpoint}")

    print("==> Preparing data..")
    test_loader = DataLoader(
        ModelNet40(partition="test", num_points=args.num_points),
        num_workers=4,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Model
    print("==> Building model..")
    # NOTE(review): --num_classes is parsed but not forwarded; the model is
    # built with its default constructor arguments and the dataset is always
    # ModelNet40 - confirm this is intended.
    net = models.__dict__[args.model]()
    criterion = cal_loss
    net = net.to(device)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))

    # Load into the bare model *before* any DataParallel wrapping and drop a
    # leading "module." prefix if present, so the same checkpoint loads on
    # both CPU and GPU regardless of whether it was saved from a wrapped model.
    state_dict = checkpoint["net"]
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    net.load_state_dict(state_dict)

    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    test_out = validate(net, test_loader, criterion, device)
    print(f"Vanilla out: {test_out}")
|
|
|
|
|
|
def validate(net, testloader, criterion, device):
    """Run ``net`` over ``testloader`` without gradients and collect metrics.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        testloader: Iterable with ``len()`` yielding ``(data, label)`` tensor
            pairs.
        criterion: Callable ``criterion(logits, label)`` returning a scalar
            loss tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        dict with keys ``"loss"`` (mean of the per-batch losses), ``"acc"``
        (overall accuracy in percent), ``"acc_avg"`` (balanced accuracy in
        percent), each rounded to three decimals, and ``"time"`` (elapsed
        whole seconds).

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    test_true = []
    test_pred = []
    start_time = datetime.datetime.now()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(testloader):
            # Swap the last two axes of the input before the forward pass
            # (presumably points/channels - confirm against the dataset).
            data = data.to(device).permute(0, 2, 1)
            # Flatten to 1-D rather than squeeze(): a plain squeeze() turns a
            # final batch of size 1 into a 0-d tensor, which breaks
            # label.size(0) and np.concatenate below. Assumes one class index
            # per sample - TODO confirm.
            label = label.to(device).reshape(-1)
            logits = net(data)
            loss = criterion(logits, label)
            test_loss += loss.item()
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            num_batches = batch_idx + 1
            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (test_loss / num_batches, 100.0 * correct / total, correct, total),
            )

    # Without this guard an empty loader fails inside np.concatenate with an
    # unhelpful message.
    if num_batches == 0:
        raise ValueError("testloader yielded no batches; nothing to evaluate")

    time_cost = int((datetime.datetime.now() - start_time).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / num_batches)),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|