fix tensor device issue

https://github.com/nv-tlabs/LION/issues/31#issue-1627736930
the original code index cpu tensor with cuda tensor (works in torch 1.10.2),
may fail in other torch version?
This commit is contained in:
xzeng 2023-03-16 12:44:57 -04:00
parent 531d6956b5
commit 4d6af1d8b9

View file

@ -52,7 +52,7 @@ def compute_NLL_metric(gen_pcs, ref_pcs, device, writer=None, output_name='', ba
for k in metrics.keys():
sorted, indices = torch.sort(metrics[k])
worse_ten, worse_score = indices[-10:], sorted[-10:]
worse_ten, worse_score = indices[-10:].cpu(), sorted[-10:].cpu()
titles = 'nll/worst-%s-%s' % (k, tag)
subtitles = [['ori', 'gen-%s=%.2fx1e-2' %
(k, worse_score[j]*1e2)] for j in range(len(worse_score))]