PointFlow/demo.py
2023-04-06 15:51:19 +02:00

81 lines
2.6 KiB
Python

# import open3d as o3d
from datasets import get_datasets
from args import get_args
from models.networks import PointFlow
import os
import torch
import numpy as np
import torch.nn as nn
def _load_model(args):
    """Build a PointFlow model on CUDA, wrap it for multi-GPU use and load weights.

    Args:
        args: Parsed command-line options (from ``get_args``); this helper
            reads ``args.resume_checkpoint`` and passes ``args`` to PointFlow.

    Returns:
        The PointFlow model with the checkpoint weights loaded, in eval mode.
    """
    model = PointFlow(args)

    def _transform_(m):
        return nn.DataParallel(m)

    model = model.cuda()
    model.multi_gpu_wrapper(_transform_)

    print("Resume Path:%s" % args.resume_checkpoint)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(args.resume_checkpoint)
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def _load_dataset_stats(args):
    """Return the (mean, std) CUDA tensors used to de-normalize generated points.

    If both ``args.resume_dataset_mean`` and ``args.resume_dataset_std`` are
    given, the test dataset is first renormalized with the statistics loaded
    from those ``.npy`` files.
    """
    _, te_dataset = get_datasets(args)
    if args.resume_dataset_mean is not None and args.resume_dataset_std is not None:
        mean = np.load(args.resume_dataset_mean)
        std = np.load(args.resume_dataset_std)
        te_dataset.renormalize(mean, std)
    ds_mean = torch.from_numpy(te_dataset.all_points_mean).cuda()
    ds_std = torch.from_numpy(te_dataset.all_points_std).cuda()
    return ds_mean, ds_std


def _sample_shapes(model, num_shapes, num_points, batch_size, ds_mean, ds_std):
    """Draw ``num_shapes`` point clouds in batches of ``batch_size``.

    Each sample is scaled by ``ds_std`` and shifted by ``ds_mean``, and the
    batches are concatenated along dim 0.

    Returns:
        A NumPy array whose first axis indexes the generated shapes.
    """
    batches = []
    with torch.no_grad():
        for start in range(0, num_shapes, batch_size):
            # The final batch may be smaller than batch_size.
            cur_batch = min(batch_size, num_shapes - start)
            _, out_pc = model.sample(cur_batch, num_points)
            # Map samples back out of the normalized space using the
            # dataset statistics.
            batches.append(out_pc * ds_std + ds_mean)
    return torch.cat(batches, dim=0).cpu().detach().numpy()


def _visualize_samples(sample_pcs):
    """Show each point cloud in ``sample_pcs`` in its own blocking 3D figure."""
    import matplotlib
    # The backend must be chosen before pyplot is imported.
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm

    total = int(sample_pcs.shape[0])
    for i in range(total):
        print("Visualizing: %03d/%03d" % (i, total))
        pts = sample_pcs[i].reshape(-1, 3)
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap=cm.jet)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        # set_aspect('equal') raises NotImplementedError on 3D axes before
        # matplotlib 3.6. With identical limits on every axis, a cubic box
        # gives equal scaling (matplotlib >= 3.3).
        ax.set_box_aspect((1, 1, 1))
        plt.show()
        plt.close(fig)


def main(args):
    """Sample shapes from a trained PointFlow checkpoint, save them and plot them.

    Args:
        args: Parsed options from ``get_args``. Reads ``resume_checkpoint``,
            ``resume_dataset_mean``, ``resume_dataset_std``,
            ``num_sample_shapes``, ``num_sample_points`` and ``batch_size``.

    Raises:
        ValueError: if ``num_sample_shapes`` or ``batch_size`` is not positive.
            Previously this failed late, inside ``torch.cat`` or ``range``.

    Side effects:
        Writes ``demo/model_out_smp.npy`` and opens one matplotlib window per
        generated shape.
    """
    if args.num_sample_shapes <= 0:
        raise ValueError(
            "num_sample_shapes must be positive, got %r" % args.num_sample_shapes)
    if args.batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % args.batch_size)

    model = _load_model(args)
    ds_mean, ds_std = _load_dataset_stats(args)

    sample_pcs = _sample_shapes(
        model,
        args.num_sample_shapes,
        args.num_sample_points,
        args.batch_size,
        ds_mean,
        ds_std,
    )
    print("Generation sample size:(%s, %s, %s)" % sample_pcs.shape)

    # Save the generative output.
    os.makedirs("demo", exist_ok=True)
    np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs)

    _visualize_samples(sample_pcs)
if __name__ == '__main__':
    # Script entry point: parse the CLI options and hand them straight to main().
    main(get_args())