diff --git a/utils/eval_helper.py b/utils/eval_helper.py index 95585fb..d817179 100644 --- a/utils/eval_helper.py +++ b/utils/eval_helper.py @@ -52,7 +52,7 @@ def compute_NLL_metric(gen_pcs, ref_pcs, device, writer=None, output_name='', ba for k in metrics.keys(): sorted, indices = torch.sort(metrics[k]) - worse_ten, worse_score = indices[-10:], sorted[-10:] + worse_ten, worse_score = indices[-10:].cpu(), sorted[-10:].cpu() titles = 'nll/worst-%s-%s' % (k, tag) subtitles = [['ori', 'gen-%s=%.2fx1e-2' % (k, worse_score[j]*1e2)] for j in range(len(worse_score))]