upload a test code
This commit is contained in:
parent
a59fef3634
commit
742a6c4afc
109
classification_ModelNet40/test.py
Normal file
109
classification_ModelNet40/test.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
python test.py --model pointMLP --msg 20220209053148-404
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import datetime
|
||||
import torch
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
from torch.utils.data import DataLoader
|
||||
import models as models
|
||||
from utils import progress_bar, IOStream
|
||||
from data import ModelNet40
|
||||
import sklearn.metrics as metrics
|
||||
from helper import cal_loss
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
model_names = sorted(name for name in models.__dict__
|
||||
if callable(models.__dict__[name]))
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parameters"""
|
||||
parser = argparse.ArgumentParser('training')
|
||||
parser.add_argument('-c', '--checkpoint', type=str, metavar='PATH',
|
||||
help='path to save checkpoint (default: checkpoint)')
|
||||
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='batch size in training')
|
||||
parser.add_argument('--model', default='pointMLP', help='model name [default: pointnet_cls]')
|
||||
parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
|
||||
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||
return parser.parse_args()
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
print(f"args: {args}")
|
||||
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
else:
|
||||
device = 'cpu'
|
||||
print(f"==> Using device: {device}")
|
||||
if args.msg is None:
|
||||
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||
else:
|
||||
message = "-"+args.msg
|
||||
args.checkpoint = 'checkpoints/' + args.model + message
|
||||
|
||||
print('==> Preparing data..')
|
||||
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
|
||||
batch_size=args.batch_size, shuffle=False, drop_last=False)
|
||||
# Model
|
||||
print('==> Building model..')
|
||||
net = models.__dict__[args.model]()
|
||||
criterion = cal_loss
|
||||
net = net.to(device)
|
||||
checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
|
||||
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
# criterion = criterion.to(device)
|
||||
if device == 'cuda':
|
||||
net = torch.nn.DataParallel(net)
|
||||
cudnn.benchmark = True
|
||||
net.load_state_dict(checkpoint['net'])
|
||||
|
||||
test_out = validate(net, test_loader, criterion, device)
|
||||
print(f"Vanilla out: {test_out}")
|
||||
|
||||
|
||||
def validate(net, testloader, criterion, device):
|
||||
net.eval()
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
test_true = []
|
||||
test_pred = []
|
||||
time_cost = datetime.datetime.now()
|
||||
with torch.no_grad():
|
||||
for batch_idx, (data, label) in enumerate(testloader):
|
||||
data, label = data.to(device), label.to(device).squeeze()
|
||||
data = data.permute(0, 2, 1)
|
||||
logits = net(data)
|
||||
loss = criterion(logits, label)
|
||||
test_loss += loss.item()
|
||||
preds = logits.max(dim=1)[1]
|
||||
test_true.append(label.cpu().numpy())
|
||||
test_pred.append(preds.detach().cpu().numpy())
|
||||
total += label.size(0)
|
||||
correct += preds.eq(label).sum().item()
|
||||
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
|
||||
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
|
||||
|
||||
time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
|
||||
test_true = np.concatenate(test_true)
|
||||
test_pred = np.concatenate(test_pred)
|
||||
return {
|
||||
"loss": float("%.3f" % (test_loss / (batch_idx + 1))),
|
||||
"acc": float("%.3f" % (100. * metrics.accuracy_score(test_true, test_pred))),
|
||||
"acc_avg": float("%.3f" % (100. * metrics.balanced_accuracy_score(test_true, test_pred))),
|
||||
"time": time_cost
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in a new issue