PVD/utils/visualize.py

242 lines
6.9 KiB
Python
Raw Normal View History

2021-10-19 20:54:46 +00:00
import matplotlib
2023-04-11 09:12:58 +00:00
matplotlib.use("agg")
import os
from pathlib import Path
2021-10-19 20:54:46 +00:00
import matplotlib.pyplot as plt
import numpy as np
import trimesh
2023-04-11 09:12:58 +00:00
from mpl_toolkits.mplot3d import Axes3D
2021-10-19 20:54:46 +00:00
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
Custom visualization
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
2023-04-11 09:12:58 +00:00
def export_to_pc_batch(dir, pcs, colors=None):
    """Write each point cloud in ``pcs`` to ``<dir>/sample_<i>.ply`` via ``pcwrite``.

    Args:
        dir: output directory; created (including parents) if it is missing.
        pcs: iterable of point arrays; the i-th one is passed to ``pcwrite`` as ``xyz``.
        colors: optional indexable of per-cloud colour arrays; ``colors[i]`` is
            paired with the i-th cloud. When ``None``, ``pcwrite`` receives
            ``rgb=None`` and applies its own default colour.
    """
    out_dir = Path(dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for idx, points in enumerate(pcs):
        point_colors = None if colors is None else colors[idx]
        target = os.path.join(dir, "sample_" + str(idx) + ".ply")
        pcwrite(target, points, point_colors)
2021-10-19 20:54:46 +00:00
2023-04-11 09:12:58 +00:00
def export_to_obj(dir, meshes, transform=lambda v, f: (v, f)):
    """Export every mesh in ``meshes`` as ``<dir>/sample_<i>.obj``.

    Args:
        dir: output directory; created (including parents) if it is missing.
        meshes: iterable of sequences ``(vertices, faces[, vertex_colors])``.
        transform: f(vertices, faces) --> transformed (vertices, faces),
            applied to each mesh before export. Defaults to identity.
    """
    Path(dir).mkdir(parents=True, exist_ok=True)
    for i, data in enumerate(meshes):
        # Per-mesh work (transform, trimesh build, OBJ write) is identical to the
        # single-mesh exporter, so delegate instead of duplicating it here.
        export_to_obj_single(os.path.join(dir, "sample_" + str(i) + ".obj"), data, transform)
2023-04-11 09:12:58 +00:00
def export_to_obj_single(path, data, transform=lambda v, f: (v, f)):
    """Export one mesh to an OBJ file at ``path`` using trimesh's OBJ exporter.

    Args:
        path: destination file path. The parent directory is not created here,
            so it must already exist.
        data: sequence ``(vertices, faces[, vertex_colors])``; the optional third
            element is passed to ``trimesh.Trimesh`` as ``vertex_colors``.
        transform: f(vertices, faces) --> transformed (vertices, faces),
            applied before the mesh is built. Defaults to identity.
    """
    vertices, faces = transform(data[0], data[1])
    v_color = data[2] if len(data) > 2 else None
    mesh = trimesh.Trimesh(vertices, faces, vertex_colors=v_color)
    out = trimesh.exchange.obj.export_obj(mesh)
    # The handle gets its own name rather than reusing ``f`` (the faces array),
    # and ``with`` already closes the file, so no explicit close() is needed.
    with open(path, "w") as fh:
        fh.write(out)
2023-04-11 09:12:58 +00:00
2021-10-19 20:54:46 +00:00
def meshwrite(filename, verts, faces, norms, colors):
    """Save a 3D mesh to an ASCII polygon .ply file.

    Every vertex is written with position, normal and RGB colour, and every
    face is written as a triangle (``3 i j k``).

    Args:
        filename: destination path; overwritten if it exists.
        verts: 2-D array of vertex positions; columns 0-2 of each row are written.
        faces: 2-D integer array of vertex indices; columns 0-2 of each row are written.
        norms: 2-D array of vertex normals, one row per vertex in ``verts``.
        colors: 2-D array of vertex colours, one row per vertex, written with
            ``%d`` (declared as uchar in the header).
    """
    header = (
        "ply",
        "format ascii 1.0",
        "element vertex %d" % (verts.shape[0]),
        "property float x",
        "property float y",
        "property float z",
        "property float nx",
        "property float ny",
        "property float nz",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "element face %d" % (faces.shape[0]),
        "property list uchar int vertex_index",
        "end_header",
    )
    # ``with`` guarantees the file is closed even if formatting a row raises.
    with open(filename, "w") as ply_file:
        ply_file.write("\n".join(header) + "\n")
        # Vertex list
        for i in range(verts.shape[0]):
            ply_file.write(
                "%f %f %f %f %f %f %d %d %d\n"
                % (
                    verts[i, 0],
                    verts[i, 1],
                    verts[i, 2],
                    norms[i, 0],
                    norms[i, 1],
                    norms[i, 2],
                    colors[i, 0],
                    colors[i, 1],
                    colors[i, 2],
                )
            )
        # Face list (triangles only)
        for i in range(faces.shape[0]):
            ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2]))
def pcwrite(filename, xyz, rgb=None):
    """Save a point cloud to an ASCII .ply file (vertex element only, no faces).

    Args:
        filename: destination path; overwritten if it exists.
        xyz: 2-D array of point positions; columns 0-2 of each row are written.
        rgb: optional array of per-point colours, cast to ``np.uint8`` before
            writing. When ``None``, every entry is 128 (mid-grey).
    """
    if rgb is None:
        rgb = np.ones_like(xyz) * 128
    rgb = rgb.astype(np.uint8)
    header = (
        "ply",
        "format ascii 1.0",
        "element vertex %d" % (xyz.shape[0]),
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "end_header",
    )
    # ``with`` flushes and closes the file deterministically instead of leaving
    # the handle (and any buffered rows) to garbage collection.
    with open(filename, "w") as ply_file:
        ply_file.write("\n".join(header) + "\n")
        for i in range(xyz.shape[0]):
            ply_file.write(
                "%f %f %f %d %d %d\n"
                % (
                    xyz[i, 0],
                    xyz[i, 1],
                    xyz[i, 2],
                    rgb[i, 0],
                    rgb[i, 1],
                    rgb[i, 2],
                )
            )
2021-10-19 20:54:46 +00:00
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
Matplotlib Visualization
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
    r"""Render the first ``num_shown`` voxel grids of a batch into one image.

    Args:
        out_file: path the figure is saved to.
        voxels: batch of occupancy values with a singleton axis 1 (removed via
            ``squeeze(1)``); presumably (B, 1, D, H, W) — TODO confirm with callers.
        num_shown: maximum number of grids to draw (capped at the batch size).
        threshold: values strictly greater than this are drawn as occupied.
    """
    batch_size = voxels.shape[0]
    voxels = voxels.squeeze(1) > threshold
    num_shown = min(num_shown, batch_size)
    # ~sqrt(N) columns and as many rows as needed (same layout as
    # visualize_pointcloud_batch), so every requested grid gets a subplot;
    # a floor(sqrt(N))**2 grid would silently drop the remainder.
    ncols = max(1, int(np.sqrt(num_shown)))
    nrows = max(1, (num_shown - 1) // ncols + 1)
    fig = plt.figure(figsize=(20, 20))
    for idx, grid in enumerate(voxels[:num_shown]):
        ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
        ax.voxels(grid, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5)
        ax.view_init()
        ax.axis("off")
    fig.savefig(out_file, bbox_inches="tight")
    # Close this specific figure rather than whichever one pyplot considers current.
    plt.close(fig)
2023-04-11 09:12:58 +00:00
def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
    r"""Visualizes point cloud data.

    Point columns are permuted for display: column 2 is plotted on the plot's
    x axis (labelled "Z"), column 0 on y ("X") and column 1 on z ("Y").

    Args:
        points (tensor): point data, indexed as ``points[:, k]`` for k in 0..2
        normals (tensor): normal data (if existing), same column layout as
            ``points``; drawn as black arrows of length 0.1
        out_file (string): output file; nothing is saved when ``None``
        show (bool): whether the plot should be shown
        elev (float): camera elevation passed to ``view_init``
        azim (float): camera azimuth passed to ``view_init``
    """
    fig = plt.figure()
    # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6);
    # add_subplot(projection=...) is the supported way to create 3-D axes.
    ax = fig.add_subplot(projection=Axes3D.name)
    ax.scatter(points[:, 2], points[:, 0], points[:, 1])
    if normals is not None:
        ax.quiver(
            points[:, 2], points[:, 0], points[:, 1], normals[:, 2], normals[:, 0], normals[:, 1], length=0.1, color="k"
        )
    ax.set_xlabel("Z")
    ax.set_ylabel("X")
    ax.set_zlabel("Y")
    ax.view_init(elev=elev, azim=azim)
    if out_file is not None:
        fig.savefig(out_file)
    if show:
        plt.show()
    plt.close(fig)
2023-04-11 09:12:58 +00:00
def visualize_pointcloud_batch(
    path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225
):
    """Scatter-plot each point cloud of a batch into one grid image saved at ``path``.

    Args:
        path: output image path.
        pointclouds: sequence of point sets indexed as ``pc[:, k]`` for k in 0..2;
            torch-like tensors are detached and moved to CPU, anything else goes
            through ``np.asarray``.
        pred_labels: predicted category indices (used only when ``vis_label``).
        labels: ground-truth labels, each exposing ``.item()`` (used only when ``vis_label``).
        categories: indexable mapping from label index to display name.
        vis_label: colour each cloud green ("g") when GT == prediction, red ("r")
            otherwise, and title the subplot with both names.
        target: optional per-cloud colours, used when ``vis_label`` is False;
            defaults to green.
        elev, azim: camera angles passed to ``view_init``.
    """
    batch_size = len(pointclouds)
    fig = plt.figure(figsize=(20, 20))
    # Roughly square grid. max(1, ...) on ncols prevents a zero divisor
    # (ZeroDivisionError) when the batch is empty.
    ncols = max(1, int(np.sqrt(batch_size)))
    nrows = max(1, (batch_size - 1) // ncols + 1)
    for idx, pc in enumerate(pointclouds):
        if vis_label:
            label = categories[labels[idx].item()]
            pred = categories[pred_labels[idx]]
            colour = "g" if label == pred else "r"
        elif target is None:
            colour = "g"
        else:
            colour = target[idx]
        # detach() lets tensors that require grad be converted; non-tensor
        # inputs (e.g. numpy arrays) are accepted as-is.
        if hasattr(pc, "detach"):
            pc = pc.detach().cpu().numpy()
        else:
            pc = np.asarray(pc)
        ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
        # Columns 1 and 2 are swapped so column 1 is drawn on the vertical axis.
        ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
        ax.view_init(elev=elev, azim=azim)
        ax.axis("off")
        if vis_label:
            ax.set_title("GT: {0}\nPred: {1}".format(label, pred))
    fig.savefig(path)
    plt.close(fig)
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
Plot stats
2023-04-11 09:12:58 +00:00
"""
2021-10-19 20:54:46 +00:00
def plot_stats(output_dir, stats, interval):
    """Plot each statistic against ``interval`` and save ``<output_dir>/stat.pdf``.

    Args:
        output_dir: directory the PDF is written into (not created here).
        stats: mapping of statistic name -> sequence of values; one stacked
            subplot per entry, in iteration order, with the name as y-label.
        interval: x-values shared by every statistic.
    """
    # squeeze=False keeps ``axs`` a 2-D array even for a single statistic;
    # with the default squeezing a lone Axes is returned and cannot be indexed.
    f, axs = plt.subplots(len(stats), 1, figsize=(20, len(stats) * 5), squeeze=False)
    for ax, (name, values) in zip(axs[:, 0], stats.items()):
        ax.plot(interval, values)
        ax.set_ylabel(name)
    f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
    plt.close(f)