feat: replace open3d viz by ghetto matplotlib viz
This commit is contained in:
parent
8b3bceffd7
commit
6d4f2cb9d3
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -19,5 +19,7 @@ metrics/structural_losses/makefile
|
||||||
PyMesh
|
PyMesh
|
||||||
checkpoint
|
checkpoint
|
||||||
|
|
||||||
|
pretrained_models*
|
||||||
|
|
||||||
torchdiffeq/
|
torchdiffeq/
|
||||||
demo/
|
demo/
|
||||||
|
|
28
demo.py
28
demo.py
|
@ -1,4 +1,4 @@
|
||||||
import open3d as o3d
|
# import open3d as o3d
|
||||||
from datasets import get_datasets
|
from datasets import get_datasets
|
||||||
from args import get_args
|
from args import get_args
|
||||||
from models.networks import PointFlow
|
from models.networks import PointFlow
|
||||||
|
@ -47,12 +47,32 @@ def main(args):
|
||||||
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
|
np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)
|
||||||
|
|
||||||
# Visualize the demo
|
# Visualize the demo
|
||||||
pcl = o3d.geometry.PointCloud()
|
# pcl = o3d.geometry.PointCloud()
|
||||||
|
# for i in range(int(sample_pcs.shape[0])):
|
||||||
|
# print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
||||||
|
# pts = sample_pcs[i].reshape(-1, 3)
|
||||||
|
# pcl.points = o3d.utility.Vector3dVector(pts)
|
||||||
|
# o3d.visualization.draw_geometries([pcl])
|
||||||
|
|
||||||
|
# Visualize the demo using matplotlib, each point cloud in a different figure
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.cm as cm
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
|
|
||||||
for i in range(int(sample_pcs.shape[0])):
|
for i in range(int(sample_pcs.shape[0])):
|
||||||
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0]))
|
||||||
pts = sample_pcs[i].reshape(-1, 3)
|
pts = sample_pcs[i].reshape(-1, 3)
|
||||||
pcl.points = o3d.utility.Vector3dVector(pts)
|
fig = plt.figure()
|
||||||
o3d.visualization.draw_geometries([pcl])
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
|
ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap=cm.jet)
|
||||||
|
ax.set_aspect('equal')
|
||||||
|
ax.set_xlim(-1, 1)
|
||||||
|
ax.set_ylim(-1, 1)
|
||||||
|
ax.set_zlim(-1, 1)
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in a new issue