diff --git a/classification_ModelNet40/test.py b/classification_ModelNet40/test.py new file mode 100644 index 0000000..e960e06 --- /dev/null +++ b/classification_ModelNet40/test.py @@ -0,0 +1,109 @@ +""" +python test.py --model pointMLP --msg 20220209053148-404 +""" +import argparse +import os +import datetime +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +from torch.utils.data import DataLoader +import models as models +from utils import progress_bar, IOStream +from data import ModelNet40 +import sklearn.metrics as metrics +from helper import cal_loss +import numpy as np +import torch.nn.functional as F + +model_names = sorted(name for name in models.__dict__ + if callable(models.__dict__[name])) + + +def parse_args(): + """Parameters""" + parser = argparse.ArgumentParser('training') + parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH', + help='path to save checkpoint (default: checkpoint)') + parser.add_argument('--msg', type=str, help='message after checkpoint') + parser.add_argument('--batch_size', type=int, default=16, help='batch size in training') + parser.add_argument('--model', default='pointMLP', help='model name [default: pointnet_cls]') + parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40') + parser.add_argument('--num_points', type=int, default=1024, help='Point Number') + return parser.parse_args() + +def main(): + args = parse_args() + print(f"args: {args}") + os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + + if torch.cuda.is_available(): + device = 'cuda' + else: + device = 'cpu' + print(f"==> Using device: {device}") + if args.msg is None: + message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) + else: + message = "-"+args.msg + args.checkpoint = 'checkpoints/' + args.model + message + + print('==> Preparing data..') + test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4, + batch_size=args.batch_size, shuffle=False, drop_last=False) + # Model + print('==> Building model..') + net = models.__dict__[args.model]() + criterion = cal_loss + net = net.to(device) + checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth') + checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) + # criterion = criterion.to(device) + if device == 'cuda': + net = torch.nn.DataParallel(net) + cudnn.benchmark = True + net.load_state_dict(checkpoint['net']) + + test_out = validate(net, test_loader, criterion, device) + print(f"Vanilla out: {test_out}") + + +def validate(net, testloader, criterion, device): + net.eval() + test_loss = 0 + correct = 0 + total = 0 + test_true = [] + test_pred = [] + time_cost = datetime.datetime.now() + with torch.no_grad(): + for batch_idx, (data, label) in enumerate(testloader): + data, label = data.to(device), label.to(device).squeeze() + data = data.permute(0, 2, 1) + logits = net(data) + loss = criterion(logits, label) + test_loss += loss.item() + preds = logits.max(dim=1)[1] + test_true.append(label.cpu().numpy()) + test_pred.append(preds.detach().cpu().numpy()) + total += label.size(0) + correct += preds.eq(label).sum().item() + progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' + % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total)) + + time_cost = int((datetime.datetime.now() - time_cost).total_seconds()) + test_true = np.concatenate(test_true) + test_pred = np.concatenate(test_pred) + return { + "loss": float("%.3f" % (test_loss / (batch_idx + 1))), + "acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))), + "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))), + "time": time_cost + } + + +if __name__ == '__main__': + main()