From 6d4f2cb9d385ec02295febc5b63845a5381e8d70 Mon Sep 17 00:00:00 2001 From: Laurent FAINSIN Date: Thu, 6 Apr 2023 15:51:19 +0200 Subject: [PATCH] feat: replace open3d viz by ghetto matplotlib viz --- .gitignore | 2 ++ demo.py | 28 ++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 56dc157..5e64bed 100644 --- a/.gitignore +++ b/.gitignore @@ -19,5 +19,7 @@ metrics/structural_losses/makefile PyMesh checkpoint +pretrained_models* + torchdiffeq/ demo/ diff --git a/demo.py b/demo.py index 8995249..4481cc1 100644 --- a/demo.py +++ b/demo.py @@ -1,4 +1,4 @@ -import open3d as o3d +# import open3d as o3d from datasets import get_datasets from args import get_args from models.networks import PointFlow @@ -47,12 +47,32 @@ def main(args): np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs) # Visualize the demo - pcl = o3d.geometry.PointCloud() + # pcl = o3d.geometry.PointCloud() + # for i in range(int(sample_pcs.shape[0])): + # print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0])) + # pts = sample_pcs[i].reshape(-1, 3) + # pcl.points = o3d.utility.Vector3dVector(pts) + # o3d.visualization.draw_geometries([pcl]) + + # Visualize the demo using matplotlib, each point cloud in a different figure + import matplotlib.pyplot as plt + import matplotlib.cm as cm + import matplotlib + matplotlib.use('TkAgg') + for i in range(int(sample_pcs.shape[0])): print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0])) pts = sample_pcs[i].reshape(-1, 3) - pcl.points = o3d.utility.Vector3dVector(pts) - o3d.visualization.draw_geometries([pcl]) + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap=cm.jet) + ax.set_aspect('equal') + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + ax.set_zlim(-1, 1) + plt.show() + plt.close() + if __name__ == '__main__':