# import open3d as o3d from datasets import get_datasets from args import get_args from models.networks import PointFlow import os import torch import numpy as np import torch.nn as nn def main(args): model = PointFlow(args) def _transform_(m): return nn.DataParallel(m) model = model.cuda() model.multi_gpu_wrapper(_transform_) print("Resume Path:%s" % args.resume_checkpoint) checkpoint = torch.load(args.resume_checkpoint) model.load_state_dict(checkpoint) model.eval() _, te_dataset = get_datasets(args) if args.resume_dataset_mean is not None and args.resume_dataset_std is not None: mean = np.load(args.resume_dataset_mean) std = np.load(args.resume_dataset_std) te_dataset.renormalize(mean, std) ds_mean = torch.from_numpy(te_dataset.all_points_mean).cuda() ds_std = torch.from_numpy(te_dataset.all_points_std).cuda() all_sample = [] with torch.no_grad(): for i in range(0, args.num_sample_shapes, args.batch_size): B = len(range(i, min(i + args.batch_size, args.num_sample_shapes))) N = args.num_sample_points _, out_pc = model.sample(B, N) out_pc = out_pc * ds_std + ds_mean all_sample.append(out_pc) sample_pcs = torch.cat(all_sample, dim=0).cpu().detach().numpy() print("Generation sample size:(%s, %s, %s)" % sample_pcs.shape) # Save the generative output os.makedirs("demo", exist_ok=True) np.save(os.path.join("demo", "model_out_smp.npy"), sample_pcs) # Visualize the demo # pcl = o3d.geometry.PointCloud() # for i in range(int(sample_pcs.shape[0])): # print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0])) # pts = sample_pcs[i].reshape(-1, 3) # pcl.points = o3d.utility.Vector3dVector(pts) # o3d.visualization.draw_geometries([pcl]) # Visualize the demo using matplotlib, each point cloud in a different figure import matplotlib.pyplot as plt import matplotlib.cm as cm import matplotlib matplotlib.use('TkAgg') for i in range(int(sample_pcs.shape[0])): print("Visualizing: %03d/%03d" % (i, sample_pcs.shape[0])) pts = sample_pcs[i].reshape(-1, 3) fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], c=pts[:, 2], cmap=cm.jet) ax.set_aspect('equal') ax.set_xlim(-1, 1) ax.set_ylim(-1, 1) ax.set_zlim(-1, 1) plt.show() plt.close() if __name__ == '__main__': args = get_args() main(args)