fix tensor device issue
https://github.com/nv-tlabs/LION/issues/31#issue-1627736930 the original code index cpu tensor with cuda tensor (works in torch 1.10.2), may fail in other torch version?
This commit is contained in:
parent
531d6956b5
commit
4d6af1d8b9
|
@ -52,7 +52,7 @@ def compute_NLL_metric(gen_pcs, ref_pcs, device, writer=None, output_name='', ba
|
|||
|
||||
for k in metrics.keys():
|
||||
sorted, indices = torch.sort(metrics[k])
|
||||
worse_ten, worse_score = indices[-10:], sorted[-10:]
|
||||
worse_ten, worse_score = indices[-10:].cpu(), sorted[-10:].cpu()
|
||||
titles = 'nll/worst-%s-%s' % (k, tag)
|
||||
subtitles = [['ori', 'gen-%s=%.2fx1e-2' %
|
||||
(k, worse_score[j]*1e2)] for j in range(len(worse_score))]
|
||||
|
|
Loading…
Reference in a new issue