117 lines
4 KiB
Python
117 lines
4 KiB
Python
"""
|
|
python test.py --model pointMLP --checkpoint checkpoints/pointMLP-20220209053148-404/best_checkpoint.pth
|
|
"""
|
|
import argparse
|
|
import os
|
|
import datetime
|
|
import torch
|
|
import torch.nn.parallel
|
|
import torch.backends.cudnn as cudnn
|
|
import torch.optim
|
|
import torch.utils.data
|
|
import torch.utils.data.distributed
|
|
from torch.utils.data import DataLoader
|
|
import models as models
|
|
from utils import progress_bar, IOStream
|
|
from data import ModelNet40
|
|
import sklearn.metrics as metrics
|
|
from helper import cal_loss
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
|
|
# Alphabetically sorted names of every callable attribute found in the
# `models` module namespace.
model_names = sorted(
    attr_name
    for attr_name, attr in vars(models).items()
    if callable(attr)
)
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the evaluation script.

    Recognised options:
        -c / --checkpoint PATH: checkpoint file to load (no default, so the
            attribute is ``None`` when the flag is omitted).
        --msg: free-form message string (no default).
        --batch_size: integer batch size, default 16.
        --model: model name, default ``'pointMLP'``.
        --num_classes: 10 or 40, default 40.
        --num_points: integer number of points, default 1024.

    Returns:
        argparse.Namespace: with attributes ``checkpoint``, ``msg``,
        ``batch_size``, ``model``, ``num_classes`` and ``num_points``.
    """
    parser = argparse.ArgumentParser('training')
    # The checkpoint is read (torch.load) by main(), never written, and the
    # option has no default value -- the help text now says so.
    parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
                        help='path of the checkpoint file to load (default: None)')
    parser.add_argument('--msg', type=str, help='message after checkpoint')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch size used for evaluation')
    # Help text previously advertised "pointnet_cls" although the actual
    # default is "pointMLP".
    parser.add_argument('--model', default='pointMLP',
                        help='model name [default: pointMLP]')
    parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40],
                        help='training on ModelNet10/40')
    parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
    return parser.parse_args()
|
|
|
|
def _align_state_dict_keys(net, state_dict):
    """Return ``state_dict`` with keys matching ``net.state_dict()`` if possible.

    ``torch.nn.DataParallel`` stores its wrapped model under ``module``, so a
    checkpoint saved from a wrapped model has every key prefixed with
    ``'module.'``. When the stored keys differ from the model's keys only by
    that prefix, it is stripped or added as needed. If neither variant matches,
    the mapping is returned unchanged so ``load_state_dict`` reports the
    real mismatch.
    """
    prefix = 'module.'
    expected = set(net.state_dict())
    if set(state_dict) == expected:
        return state_dict

    stripped = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    if set(stripped) == expected:
        return stripped

    prefixed = {prefix + key: value for key, value in state_dict.items()}
    if set(prefixed) == expected:
        return prefixed

    return state_dict


def main():
    """Evaluate a checkpointed model on the ModelNet40 test partition.

    Steps: parse the CLI arguments, set ``HDF5_USE_FILE_LOCKING=FALSE`` in the
    environment, pick CUDA when available (CPU otherwise), build the test
    ``DataLoader``, instantiate ``models.__dict__[args.model]()``, load the
    ``'net'`` entry of the checkpoint into it, run ``validate`` and print the
    resulting metrics.

    NOTE: ``args.msg`` and ``args.num_classes`` are parsed but not used here.

    Raises:
        ValueError: if no ``--checkpoint`` path was supplied.
    """
    args = parse_args()
    print(f"args: {args}")
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"==> Using device: {device}")

    # Fail before the dataset and model are built: torch.load() below cannot
    # work without a path.
    if args.checkpoint is None:
        raise ValueError("no checkpoint given; pass -c/--checkpoint PATH")
    print(f"==> Using checkpoint: {args.checkpoint}")

    print('==> Preparing data..')
    test_loader = DataLoader(
        ModelNet40(partition='test', num_points=args.num_points),
        num_workers=4,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Model
    print('==> Building model..')
    net = models.__dict__[args.model]()
    criterion = cal_loss
    net = net.to(device)

    # Load onto CPU first; load_state_dict copies the values into the
    # parameters wherever the model currently lives.
    checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))

    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # The model is wrapped in DataParallel only on CUDA, so the checkpoint's
    # keys may or may not carry the 'module.' prefix relative to `net`.
    net.load_state_dict(_align_state_dict_keys(net, checkpoint['net']))

    test_out = validate(net, test_loader, criterion, device)
    print(f"Vanilla out: {test_out}")
|
|
|
|
|
|
def validate(net, testloader, criterion, device):
    """Evaluate ``net`` on every batch of ``testloader`` with gradients disabled.

    Args:
        net: model invoked as ``net(data)``; switched to eval mode here.
        testloader: iterable of ``(data, label)`` batches that supports
            ``len()`` (used for the progress bar).
        criterion: callable ``criterion(logits, label)`` returning a scalar
            tensor (``.item()`` is called on the result).
        device: target passed to ``Tensor.to`` for both data and labels.

    Returns:
        dict: ``"loss"`` (mean of the per-batch losses), ``"acc"`` (overall
        accuracy in percent), ``"acc_avg"`` (balanced accuracy in percent),
        each rounded to three decimals, and ``"time"`` (elapsed wall-clock
        time in whole seconds).

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    test_true = []
    test_pred = []
    started_at = datetime.datetime.now()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(testloader):
            data = data.to(device)
            # Flatten rather than squeeze(): squeeze() turns a batch holding a
            # single sample into a 0-dim tensor, which breaks label.size(0)
            # and np.concatenate below. Assumes one class index per sample.
            label = label.to(device).reshape(-1)
            # Swap the last two axes before the forward pass.
            # NOTE(review): presumably (batch, points, channels) ->
            # (batch, channels, points) -- confirm against the dataset/model.
            data = data.permute(0, 2, 1)
            logits = net(data)
            loss = criterion(logits, label)
            test_loss += loss.item()
            # Index of the largest logit along dim 1 is the predicted class.
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            num_batches = batch_idx + 1
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / num_batches, 100. * correct / total, correct, total))

    if num_batches == 0:
        # Without this guard the averages below would divide by an undefined
        # batch count and np.concatenate would receive an empty list.
        raise ValueError("testloader produced no batches; nothing to evaluate")

    time_cost = int((datetime.datetime.now() - started_at).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / num_batches)),
        "acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|