# PointMLP/classification_ModelNet40/test.py
"""python test.py --model pointMLP --msg 20220209053148-404."""
import argparse
import datetime
import os
import models as models
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from data import ModelNet40
from helper import cal_loss
from torch.utils.data import DataLoader
from utils import progress_bar

# Names of all callable entries exposed by the `models` package.
model_names = sorted(name for name in models.__dict__ if callable(models.__dict__[name]))


def parse_args():
"""Parameters"""
parser = argparse.ArgumentParser("training")
parser.add_argument(
"-c",
"--checkpoint",
type=str,
metavar="PATH",
help="path to save checkpoint (default: checkpoint)",
)
parser.add_argument("--msg", type=str, help="message after checkpoint")
parser.add_argument("--batch_size", type=int, default=16, help="batch size in training")
parser.add_argument("--model", default="pointMLP", help="model name [default: pointnet_cls]")
parser.add_argument("--num_classes", default=40, type=int, choices=[10, 40], help="training on ModelNet10/40")
parser.add_argument("--num_points", type=int, default=1024, help="Point Number")
return parser.parse_args()


def main():
    args = parse_args()
    print(f"args: {args}")
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"==> Using device: {device}")
    # if args.msg is None:
    #     message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
    # else:
    #     message = "-" + args.msg
    # args.checkpoint = 'checkpoints/' + args.model + message
    if args.checkpoint is not None:
        print(f"==> Using checkpoint: {args.checkpoint}")

    print("==> Preparing data..")
    test_loader = DataLoader(
        ModelNet40(partition="test", num_points=args.num_points),
        num_workers=4,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Model
    print("==> Building model..")
    net = models.__dict__[args.model]()
    criterion = cal_loss
    net = net.to(device)
    # checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
    checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
    # criterion = criterion.to(device)
    if device == "cuda":
        # The checkpoint was presumably saved from a DataParallel model, so wrap
        # `net` the same way before loading its "module."-prefixed state dict.
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net.load_state_dict(checkpoint["net"])

    test_out = validate(net, test_loader, criterion, device)
    print(f"Vanilla out: {test_out}")


def validate(net, testloader, criterion, device):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    test_true = []
    test_pred = []
    time_cost = datetime.datetime.now()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(testloader):
            data, label = data.to(device), label.to(device).squeeze()
            # The model expects channels-first input: (batch, 3, num_points).
            data = data.permute(0, 2, 1)
            logits = net(data)
            loss = criterion(logits, label)
            test_loss += loss.item()
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
            )
    time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }
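

# Example invocation (matching the module docstring; the checkpoint path below is
# illustrative and follows the commented-out layout in main()):
#   python test.py --model pointMLP --msg 20220209053148-404 \
#       --checkpoint checkpoints/pointMLP-20220209053148-404/best_checkpoint.pth
# validate() returns a dict of the form
#   {"loss": <average loss>, "acc": <overall accuracy %>,
#    "acc_avg": <class-balanced accuracy %>, "time": <evaluation time in seconds>}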
if __name__ == "__main__":
main()