diff --git a/test_generation.py b/test_generation.py index 5299b08..7d35c54 100644 --- a/test_generation.py +++ b/test_generation.py @@ -476,12 +476,31 @@ def generate(model, opt): # None, # None # ) - visualize_voxels( - os.path.join(str(Path(opt.eval_path).parent), 'x.png'), - gen[:64], - 1, - 0.5, - ) + # visualize_voxels( + # os.path.join(str(Path(opt.eval_path).parent), 'x.png'), + # gen[:64], + # 1, + # 0.5, + # ) + + + # visualize using matplotlib + import matplotlib.pyplot as plt + import matplotlib.cm as cm + import matplotlib + matplotlib.use('TkAgg') + for idx, pc in enumerate(gen[:64]): + print(f"Visualizing point cloud {idx}...") + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(pc[:,0], pc[:,1], pc[:,2], c=pc[:,2], cmap=cm.jet) + ax.set_aspect('equal') + ax.axis('off') + # ax.set_xlim(-1, 1) + # ax.set_ylim(-1, 1) + # ax.set_zlim(-1, 1) + plt.show() + plt.close() samples = torch.cat(samples, dim=0) ref = torch.cat(ref, dim=0)