223 lines
6.7 KiB
Python
223 lines
6.7 KiB
Python
import matplotlib
|
|
matplotlib.use('agg')
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
|
import numpy as np
|
|
import os
|
|
import trimesh
|
|
from pathlib import Path
|
|
|
|
'''
|
|
Custom visualization
|
|
'''
|
|
|
|
def export_to_pc_batch(dir, pcs, colors=None):
|
|
|
|
Path(dir).mkdir(parents=True, exist_ok=True)
|
|
for i, xyz in enumerate(pcs):
|
|
if colors is None:
|
|
color = None
|
|
else:
|
|
color = colors[i]
|
|
pcwrite(os.path.join(dir, 'sample_'+str(i)+'.ply'), xyz, color)
|
|
|
|
|
|
def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)):
|
|
'''
|
|
transform: f(vertices, faces) --> transformed (vertices, faces)
|
|
'''
|
|
Path(dir).mkdir(parents=True, exist_ok=True)
|
|
for i, data in enumerate(meshes):
|
|
v, f = transform(data[0], data[1])
|
|
if len(data) > 2:
|
|
v_color = data[2]
|
|
else:
|
|
v_color = None
|
|
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
|
out = trimesh.exchange.obj.export_obj(mesh)
|
|
with open(os.path.join(dir, 'sample_'+str(i)+'.obj'), 'w') as f:
|
|
f.write(out)
|
|
f.close()
|
|
|
|
def export_to_obj_single(path, data, transform=lambda v,f:(v,f)):
|
|
'''
|
|
transform: f(vertices, faces) --> transformed (vertices, faces)
|
|
'''
|
|
v, f = transform(data[0], data[1])
|
|
if len(data) > 2:
|
|
v_color = data[2]
|
|
else:
|
|
v_color = None
|
|
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
|
out = trimesh.exchange.obj.export_obj(mesh)
|
|
with open(path, 'w') as f:
|
|
f.write(out)
|
|
f.close()
|
|
|
|
def meshwrite(filename, verts, faces, norms, colors):
|
|
"""Save a 3D mesh to a polygon .ply file.
|
|
"""
|
|
# Write header
|
|
ply_file = open(filename, 'w')
|
|
ply_file.write("ply\n")
|
|
ply_file.write("format ascii 1.0\n")
|
|
ply_file.write("element vertex %d\n" % (verts.shape[0]))
|
|
ply_file.write("property float x\n")
|
|
ply_file.write("property float y\n")
|
|
ply_file.write("property float z\n")
|
|
ply_file.write("property float nx\n")
|
|
ply_file.write("property float ny\n")
|
|
ply_file.write("property float nz\n")
|
|
ply_file.write("property uchar red\n")
|
|
ply_file.write("property uchar green\n")
|
|
ply_file.write("property uchar blue\n")
|
|
ply_file.write("element face %d\n" % (faces.shape[0]))
|
|
ply_file.write("property list uchar int vertex_index\n")
|
|
ply_file.write("end_header\n")
|
|
|
|
# Write vertex list
|
|
for i in range(verts.shape[0]):
|
|
ply_file.write("%f %f %f %f %f %f %d %d %d\n" % (
|
|
verts[i, 0], verts[i, 1], verts[i, 2],
|
|
norms[i, 0], norms[i, 1], norms[i, 2],
|
|
colors[i, 0], colors[i, 1], colors[i, 2],
|
|
))
|
|
|
|
# Write face list
|
|
for i in range(faces.shape[0]):
|
|
ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2]))
|
|
|
|
ply_file.close()
|
|
|
|
|
|
def pcwrite(filename, xyz, rgb=None):
|
|
"""Save a point cloud to a polygon .ply file.
|
|
"""
|
|
if rgb is None:
|
|
rgb = np.ones_like(xyz) * 128
|
|
rgb = rgb.astype(np.uint8)
|
|
|
|
# Write header
|
|
ply_file = open(filename, 'w')
|
|
ply_file.write("ply\n")
|
|
ply_file.write("format ascii 1.0\n")
|
|
ply_file.write("element vertex %d\n" % (xyz.shape[0]))
|
|
ply_file.write("property float x\n")
|
|
ply_file.write("property float y\n")
|
|
ply_file.write("property float z\n")
|
|
ply_file.write("property uchar red\n")
|
|
ply_file.write("property uchar green\n")
|
|
ply_file.write("property uchar blue\n")
|
|
ply_file.write("end_header\n")
|
|
|
|
# Write vertex list
|
|
for i in range(xyz.shape[0]):
|
|
ply_file.write("%f %f %f %d %d %d\n" % (
|
|
xyz[i, 0], xyz[i, 1], xyz[i, 2],
|
|
rgb[i, 0], rgb[i, 1], rgb[i, 2],
|
|
))
|
|
|
|
'''
|
|
Matplotlib Visualization
|
|
'''
|
|
|
|
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
|
|
r''' Visualizes voxel data.
|
|
show only first num_shown
|
|
'''
|
|
batch_size =voxels.shape[0]
|
|
voxels = voxels.squeeze(1) > threshold
|
|
|
|
num_shown = min(num_shown, batch_size)
|
|
|
|
n = int(np.sqrt(num_shown))
|
|
fig = plt.figure(figsize=(20,20))
|
|
|
|
for idx, pc in enumerate(voxels[:num_shown]):
|
|
if idx >= n*n:
|
|
break
|
|
pc = voxels[idx]
|
|
ax = fig.add_subplot(n, n, idx + 1, projection='3d')
|
|
ax.voxels(pc, edgecolor='k', facecolors='green', linewidth=0.1, alpha=0.5)
|
|
ax.view_init()
|
|
ax.axis('off')
|
|
plt.savefig(out_file, bbox_inches='tight')
|
|
plt.close()
|
|
|
|
def visualize_pointcloud(points, normals=None,
|
|
out_file=None, show=False, elev=30, azim=225):
|
|
r''' Visualizes point cloud data.
|
|
Args:
|
|
points (tensor): point data
|
|
normals (tensor): normal data (if existing)
|
|
out_file (string): output file
|
|
show (bool): whether the plot should be shown
|
|
'''
|
|
# Create plot
|
|
fig = plt.figure()
|
|
ax = fig.gca(projection=Axes3D.name)
|
|
ax.scatter(points[:, 2], points[:, 0], points[:, 1])
|
|
if normals is not None:
|
|
ax.quiver(
|
|
points[:, 2], points[:, 0], points[:, 1],
|
|
normals[:, 2], normals[:, 0], normals[:, 1],
|
|
length=0.1, color='k'
|
|
)
|
|
ax.set_xlabel('Z')
|
|
ax.set_ylabel('X')
|
|
ax.set_zlabel('Y')
|
|
# ax.set_xlim(-0.5, 0.5)
|
|
# ax.set_ylim(-0.5, 0.5)
|
|
# ax.set_zlim(-0.5, 0.5)
|
|
ax.view_init(elev=elev, azim=azim)
|
|
if out_file is not None:
|
|
plt.savefig(out_file)
|
|
if show:
|
|
plt.show()
|
|
plt.close(fig)
|
|
|
|
|
|
def visualize_pointcloud_batch(path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225):
|
|
batch_size = len(pointclouds)
|
|
fig = plt.figure(figsize=(20,20))
|
|
|
|
ncols = int(np.sqrt(batch_size))
|
|
nrows = max(1, (batch_size-1) // ncols+1)
|
|
for idx, pc in enumerate(pointclouds):
|
|
if vis_label:
|
|
label = categories[labels[idx].item()]
|
|
pred = categories[pred_labels[idx]]
|
|
colour = 'g' if label == pred else 'r'
|
|
elif target is None:
|
|
|
|
colour = 'g'
|
|
else:
|
|
colour = target[idx]
|
|
pc = pc.cpu().numpy()
|
|
ax = fig.add_subplot(nrows, ncols, idx + 1, projection='3d')
|
|
ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
|
|
ax.view_init(elev=elev, azim=azim)
|
|
ax.axis('off')
|
|
if vis_label:
|
|
ax.set_title('GT: {0}\nPred: {1}'.format(label, pred))
|
|
|
|
plt.savefig(path)
|
|
plt.close(fig)
|
|
|
|
|
|
'''
|
|
Plot stats
|
|
'''
|
|
|
|
def plot_stats(output_dir, stats, interval):
|
|
content = stats.keys()
|
|
# f = plt.figure(figsize=(20, len(content) * 5))
|
|
f, axs = plt.subplots(len(content), 1, figsize=(20, len(content) * 5))
|
|
for j, (k, v) in enumerate(stats.items()):
|
|
axs[j].plot(interval, v)
|
|
axs[j].set_ylabel(k)
|
|
|
|
f.savefig(os.path.join(output_dir, 'stat.pdf'), bbox_inches='tight')
|
|
plt.close(f)
|