Linqi (Alex) Zhou 2021-10-19 13:54:46 -07:00
commit 2f6aa752a6
115 changed files with 20577 additions and 0 deletions

10
.gitignore vendored Normal file
View file

@ -0,0 +1,10 @@
.idea
data
output
__pycache__
*.png
*.net
*.npy
*.npz
eval_model
eval_data

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "metrics/ChamferDistancePytorch"]
path = metrics/ChamferDistancePytorch
url = https://github.com/ThibaultGROUEIX/ChamferDistancePytorch

59
README.md Normal file
View file

@ -0,0 +1,59 @@
# Shape Generation and Completion Through Point-Voxel Diffusion
[Project]() | [Paper]()
Implementation of *Shape Generation and Completion Through Point-Voxel Diffusion*.
## Pretrained Models
Pretrained models can be accessed [here](https://www.dropbox.com/s/a3xydf594fzaokl/cifar10_pretrained.rar?dl=0).
## Requirements:
Make sure the following dependencies are installed.
```
python==3.6
pytorch==1.4.0
torchvision==0.5.0
cudatoolkit==10.1
matplotlib==2.2.5
tqdm==4.32.1
open3d==0.9.0
```
The code was tested on Ubuntu with a Titan RTX GPU.
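One possible way to set up the environment (a sketch assuming conda; the environment name `pvd` is arbitrary, and version pins may need adjustment for your platform):
```bash
conda create -n pvd python=3.6
conda activate pvd
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
pip install matplotlib==2.2.5 tqdm==4.32.1 open3d==0.9.0
```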
## Training on CIFAR-10:
```bash
$ python train_cifar.py
```
Please refer to the Python file for the optimal training parameters.
## Results
Some generative results are as follows.
<p float="left">
<img src="example/cifar_gen.png" width="300"/>
<img src="example/lsun_gen.png" width="300"/>
</p>
## Reference
```
@inproceedings{han2020joint,
title={Joint Training of Variational Auto-Encoder and Latent Energy-Based Model},
author={Han, Tian and Nijkamp, Erik and Zhou, Linqi and Pang, Bo and Zhu, Song-Chun and Wu, Ying Nian},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={7978--7987},
year={2020}
}
```
## Acknowledgement
For any questions related to code and experiment settings, please contact Linqi (Alex) Zhou (alexzhou907@gmail.com). For questions related to the model and algorithm in the paper, please contact Tian Han (hantian@ucla.edu). Thanks to [@Tian Han](https://github.com/hthth0801?tab=repositories) and [@Erik Nijkamp](https://github.com/enijkamp) for their collaboration and guidance.

0
datasets/__init__.py Normal file
View file

213
datasets/partnet.py Normal file
View file

@ -0,0 +1,213 @@
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import os
import json
import random
import trimesh
import csv
from plyfile import PlyData, PlyElement
from glob import glob
def project_pc_to_image(points, resolution=64):
"""project point clouds into 2D image
:param points: (n, 3) range(-1, 1)
:return: binary image
"""
img = []
for i in range(3):
canvas = np.zeros((resolution, resolution))
axis = [0, 1, 2]
axis.remove(i)
proj_points = (points[:, axis] + 1) / 2 * resolution
proj_points = proj_points.astype(int)  # np.int is deprecated; the builtin int behaves identically here
proj_points = np.clip(proj_points, 0, resolution - 1)  # points at the +1 boundary would otherwise index out of range
canvas[proj_points[:, 0], proj_points[:, 1]] = 1
img.append(canvas)
img = np.concatenate(img, axis=1)
return img
def write_ply(points, filename, text=False):
""" input: Nx3, write points to filename as PLY format. """
points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
with open(filename, mode='wb') as f:
PlyData([el], text=text).write(f)
def rotate_point_cloud(points, transformation_mat):
new_points = np.dot(transformation_mat, points.T).T
return new_points
def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
""" align 3depn shapes to shapenet coordinates"""
# angle = math.radians(angle_deg)
# rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
# rot_m = rot_m.to_matrix()
rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
[-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
new_points = rotate_point_cloud(points, rot_m)
return new_points
def downsample_point_cloud(points, n_pts):
"""downsample points by random choice
:param points: (n, 3)
:param n_pts: int
:return:
"""
p_idx = random.choices(list(range(points.shape[0])), k=n_pts)
return points[p_idx]
def upsample_point_cloud(points, n_pts):
"""upsample points by random choice
:param points: (n, 3)
:param n_pts: int, > n
:return:
"""
p_idx = random.choices(list(range(points.shape[0])), k=n_pts - points.shape[0])
dup_points = points[p_idx]
points = np.concatenate([points, dup_points], axis=0)
return points
def sample_point_cloud_by_n(points, n_pts):
"""resample point cloud to given number of points"""
if n_pts > points.shape[0]:
return upsample_point_cloud(points, n_pts)
elif n_pts < points.shape[0]:
return downsample_point_cloud(points, n_pts)
else:
return points
def collect_data_id(split_dir, classname, phase):
filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
if not os.path.exists(filename):
raise ValueError("Invalid filepath: {}".format(filename))
all_ids = []
with open(filename, 'r') as fp:
info = json.load(fp)
for item in info:
all_ids.append(item["anno_id"])
return all_ids
class GANdatasetPartNet(Dataset):
def __init__(self, phase, data_root, category, n_pts):
super(GANdatasetPartNet, self).__init__()
if phase == "validation":
phase = "val"
self.phase = phase
self.aug = phase == "train"
self.data_root = data_root
shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase)
self.shape_names = []
for name in shape_names:
path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
if os.path.exists(path):
self.shape_names.append(name)
self.n_pts = n_pts
self.raw_n_pts = self.n_pts // 2
self.rng = random.Random(1234)
@staticmethod
def load_point_cloud(path):
pc = trimesh.load(path)
pc = pc.vertices / 2.0 # scale to unit sphere
return pc
@staticmethod
def read_point_cloud_part_label(path):
with open(path, 'r') as fp:
labels = fp.readlines()
labels = np.array([int(x) for x in labels])
return labels
def random_rm_parts(self, raw_pc, part_labels):
part_ids = sorted(np.unique(part_labels).tolist())
if self.phase == "train":
random.shuffle(part_ids)
n_part_keep = random.randint(1, max(1, len(part_ids) - 1))
else:
self.rng.shuffle(part_ids)
n_part_keep = self.rng.randint(1, max(1, len(part_ids) - 1))
part_ids_keep = part_ids[:n_part_keep]
point_idx = []
for i in part_ids_keep:
point_idx.extend(np.where(part_labels == i)[0].tolist())
raw_pc = raw_pc[point_idx]
return raw_pc, n_part_keep
def __getitem__(self, index):
raw_shape_name = self.shape_names[index]
raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
raw_pc = self.load_point_cloud(raw_ply_path)
raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
part_labels = self.read_point_cloud_part_label(raw_label_path)
raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
real_shape_name = self.shape_names[index]
real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
real_pc = self.load_point_cloud(real_ply_path)
real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
'n_part_keep': n_part_keep, 'idx': index}
def __len__(self):
return len(self.shape_names)
if __name__ == '__main__':
data_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetPointCloud'
data_raw_root = '/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc'
pc_dataroot = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k'
sn_root = '/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2'
classes = 'car'
npoints = 2048
# from datasets.shapenet_data_pc import ShapeNet15kPointClouds
# pc_ds = ShapeNet15kPointClouds(root_dir=pc_dataroot,
# categories=[classes], split='train',
# tr_sample_size=npoints,
# te_sample_size=npoints,
# scale=1.,
# normalize_per_shape=False,
# normalize_std_per_axis=False,
# random_subsample=True)
# GANdatasetPartNet takes (phase, data_root, category, n_pts)
train_ds = GANdatasetPartNet('test', pc_dataroot, classes, npoints)
d1 = train_ds[0]
real = d1['real']
raw = d1['raw']
x = torch.cat([raw, real], dim=-1).transpose(0, 1)
write_ply(x.numpy(), 'x.ply')

268
datasets/shapenet_data_pc.py Normal file
View file

@ -0,0 +1,268 @@
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils import data
import random
import open3d as o3d
import numpy as np
import torch.nn.functional as F
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
'04554684': 'washer', '02992529': 'cellphone',
'02843684': 'birdhouse', '02871439': 'bookshelf',
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
class Uniform15KPC(Dataset):
def __init__(self, root_dir, subdirs, tr_sample_size=10000,
te_sample_size=10000, split='train', scale=1.,
normalize_per_shape=False, box_per_shape=False,
random_subsample=False,
normalize_std_per_axis=False,
all_points_mean=None, all_points_std=None,
input_dim=3, use_mask=False):
self.root_dir = root_dir
self.split = split
self.in_tr_sample_size = tr_sample_size
self.in_te_sample_size = te_sample_size
self.subdirs = subdirs
self.scale = scale
self.random_subsample = random_subsample
self.input_dim = input_dim
self.use_mask = use_mask
self.box_per_shape = box_per_shape
if use_mask:
self.mask_transform = PointCloudMasks(radius=5, elev=5, azim=90)
self.all_cate_mids = []
self.cate_idx_lst = []
self.all_points = []
for cate_idx, subd in enumerate(self.subdirs):
# NOTE: [subd] here is synset id
sub_path = os.path.join(root_dir, subd, self.split)
if not os.path.isdir(sub_path):
print("Directory missing : %s" % sub_path)
continue
all_mids = []
for x in os.listdir(sub_path):
if not x.endswith('.npy'):
continue
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
for mid in all_mids:
# obj_fname = os.path.join(sub_path, x)
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
try:
point_cloud = np.load(obj_fname) # (15k, 3)
except:
continue
assert point_cloud.shape[0] == 15000
self.all_points.append(point_cloud[np.newaxis, ...])
self.cate_idx_lst.append(cate_idx)
self.all_cate_mids.append((subd, mid))
# Shuffle the index deterministically (based on the number of examples)
self.shuffle_idx = list(range(len(self.all_points)))
random.Random(38383).shuffle(self.shuffle_idx)
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
# Normalization
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
self.normalize_per_shape = normalize_per_shape
self.normalize_std_per_axis = normalize_std_per_axis
if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
elif self.normalize_per_shape: # per shape normalization
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
elif self.box_per_shape:
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
if self.box_per_shape:
self.all_points = self.all_points - 0.5
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
print("Total number of data:%d" % len(self.train_points))
print("Min number of points: (train)%d (test)%d"
% (self.tr_sample_size, self.te_sample_size))
assert self.scale == 1, "Scale (!= 1) is deprecated"
def get_pc_stats(self, idx):
if self.normalize_per_shape or self.box_per_shape:
m = self.all_points_mean[idx].reshape(1, self.input_dim)
s = self.all_points_std[idx].reshape(1, -1)
return m, s
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
def renormalize(self, mean, std):
self.all_points = self.all_points * self.all_points_std + self.all_points_mean
self.all_points_mean = mean
self.all_points_std = std
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
def __len__(self):
return len(self.train_points)
def __getitem__(self, idx):
tr_out = self.train_points[idx]
if self.random_subsample:
tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
else:
tr_idxs = np.arange(self.tr_sample_size)
tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
te_out = self.test_points[idx]
if self.random_subsample:
te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
else:
te_idxs = np.arange(self.te_sample_size)
te_out = torch.from_numpy(te_out[te_idxs, :]).float()
m, s = self.get_pc_stats(idx)
cate_idx = self.cate_idx_lst[idx]
sid, mid = self.all_cate_mids[idx]
out = {
'idx': idx,
'train_points': tr_out,
'test_points': te_out,
'mean': m, 'std': s, 'cate_idx': cate_idx,
'sid': sid, 'mid': mid
}
if self.use_mask:
# masked = torch.from_numpy(self.mask_transform(self.all_points[idx]))
# ss = min(masked.shape[0], self.in_tr_sample_size//2)
# masked = masked[:ss]
#
# tr_mask = torch.ones_like(masked)
# masked = torch.cat([masked, torch.zeros(self.in_tr_sample_size - ss, 3)],dim=0)#F.pad(masked, (self.in_tr_sample_size-masked.shape[0], 0), "constant", 0)
#
# tr_mask = torch.cat([tr_mask, torch.zeros(self.in_tr_sample_size- ss, 3)],dim=0)#F.pad(tr_mask, (self.in_tr_sample_size-tr_mask.shape[0], 0), "constant", 0)
# out['train_points_masked'] = masked
# out['train_masks'] = tr_mask
tr_mask = self.mask_transform(tr_out)
out['train_masks'] = tr_mask
return out
class ShapeNet15kPointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False, box_per_shape=False,
random_subsample=False,
all_points_mean=None, all_points_std=None,
use_mask=False):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test', 'val']
self.tr_sample_size = tr_sample_size
self.te_sample_size = te_sample_size
self.cates = categories
if 'all' in categories:
self.synset_ids = list(cate_to_synsetid.values())
else:
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
# assert 'v2' in root_dir, "Only supporting v2 right now."
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
super(ShapeNet15kPointClouds, self).__init__(
root_dir, self.synset_ids,
tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size,
split=split, scale=scale,
normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3, use_mask=use_mask)
class PointCloudMasks(object):
'''
render a view then save mask
'''
def __init__(self, radius : float=10, elev: float =45, azim:float=315, ):
self.radius = radius
self.elev = elev
self.azim = azim
def __call__(self, points):
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(points))  # accept torch tensors as well as numpy arrays
# elev/azim are in degrees; convert to radians for numpy's trigonometric functions
camera = [self.radius * np.sin(np.deg2rad(90 - self.elev)) * np.cos(np.deg2rad(self.azim)),
self.radius * np.cos(np.deg2rad(90 - self.elev)),
self.radius * np.sin(np.deg2rad(90 - self.elev)) * np.sin(np.deg2rad(self.azim)),
]
# camera = [0,self.radius,0]
_, pt_map = pcd.hidden_point_removal(camera, self.radius)
mask = torch.zeros_like(points)
mask[pt_map] = 1
return mask #points[pt_map]
####################################################################################

View file

@ -0,0 +1,257 @@
import warnings
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import open3d as o3d
import os
import numpy as np
import hashlib
import torch
import matplotlib.pyplot as plt
synset_to_label = {
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
'04554684': 'washer', '02992529': 'cellphone',
'02843684': 'birdhouse', '02871439': 'bookshelf',
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
# Label to Synset mapping (for ShapeNet core classes)
label_to_synset = {v: k for k, v in synset_to_label.items()}
def _convert_categories(categories):
assert categories is not None, 'List of categories cannot be empty!'
if not all(c in synset_to_label or c in label_to_synset
for c in categories):
warnings.warn('Some or all of the categories requested are not part of \
ShapeNetCore. Data loading may fail if these categories are not available.')
synsets = [label_to_synset[c] if c in label_to_synset
else c for c in categories]
return synsets
class ShapeNet_Multiview_Points(Dataset):
def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val',
npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
self.root = Path(root_views)
self.split = split
self.get_image = get_image
params = {
'cat': categories,
'npoints': npoints,
'sv_samples': sv_samples,
}
params = tuple(sorted(pair for pair in params.items()))
self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.paths = []
self.synset_idxs = []
self.synsets = _convert_categories(categories)
self.labels = [synset_to_label[s] for s in self.synsets]
self.npoints = npoints
self.sv_samples = sv_samples
self.all_points = []
self.all_points_sv = []
# loops through desired classes
for i in range(len(self.synsets)):
syn = self.synsets[i]
class_target = self.root / syn
if not class_target.exists():
raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
syn, self.labels[i], str(class_target)))
sub_path_pc = os.path.join(root_pc, syn, split)
if not os.path.isdir(sub_path_pc):
print("Directory missing : %s" % sub_path_pc)
continue
self.all_mids = []
self.imgs = []
for x in os.listdir(sub_path_pc):
if not x.endswith('.npy'):
continue
self.all_mids.append(os.path.join(split, x[:-len('.npy')]))
for mid in tqdm(self.all_mids):
# obj_fname = os.path.join(sub_path, x)
obj_fname = os.path.join(root_pc, syn, mid + ".npy")
cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz'))
if len(cams_pths) < 20:
continue
point_cloud = np.load(obj_fname)
sv_points_group = []
img_path_group = []
(self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True)
success = True
for i, cp in enumerate(cams_pths):
cp = str(cp)
vp = cp.split('cam_params')[0] + 'depth.png'
depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) )
cam_params = np.load(cp)
extr = cam_params['extr']
intr = cam_params['intr']
self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
try:
sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth)
img_path_group.append(vp)
sv_points_group.append(sv_point_cloud)
except Exception as e:
print(e)
success=False
break
if not success:
continue
self.all_points_sv.append(np.stack(sv_points_group, axis=0))
self.all_points.append(point_cloud)
self.imgs.append(img_path_group)
self.all_points = np.stack(self.all_points, axis=0)
self.all_points_sv = np.stack(self.all_points_sv, axis=0)
if all_points_mean is not None and all_points_std is not None: # using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:,:10000]
self.test_points = self.all_points[:,10000:]
self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std
def get_pc_stats(self, idx):
return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1)
def __len__(self):
"""Returns the length of the dataset. """
return len(self.all_points)
def __getitem__(self, index):
tr_out = self.train_points[index]
tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
tr_out = tr_out[tr_idxs, :]
gt_points = self.test_points[index][:self.npoints]
m, s = self.get_pc_stats(index)
sv_points = self.all_points_sv[index]
idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False)
data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(),
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
masks = torch.zeros_like(data)
masks[:,:idxs.shape[0]] = 1
res = {'train_points': torch.from_numpy(tr_out).float(),
'test_points': torch.from_numpy(gt_points).float(),
'sv_points': data,
'masks': masks,
'std': s, 'mean': m,
'idx': index,
'name':self.all_mids[index]
}
if self.split != 'train' and self.get_image:
img_lst = []
for n in range(self.all_points_sv.shape[1]):
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3]
img_lst.append(img)
img = torch.stack(img_lst, dim=0)
res['image'] = img
return res
def _render(self, cache_path, depth_pth, depth_minmax_pth):
# if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
#
# os.remove(cache_path)
if os.path.exists(cache_path):
data = np.load(cache_path)
else:
data, depth = self.transform(depth_pth, depth_minmax_pth)
assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
data = data[np.random.choice(data.shape[0], 600, replace=False)]
np.save(cache_path, data)
return data
class DepthToSingleViewPoints(object):
'''
render a view then save mask
'''
def __init__(self, cam_ext, cam_int):
self.cam_ext = cam_ext.reshape(4,4)
self.cam_int = cam_int.reshape(3,3)
def __call__(self, depth_pth, depth_minmax_pth):
depth_minmax = np.load(depth_minmax_pth)
depth_img = plt.imread(depth_pth)[...,0]
mask = np.where(depth_img == 0, -1.0, 1.0)
depth_img = 1 - depth_img
depth_img = (depth_img * (np.max(depth_minmax) - np.min(depth_minmax)) + np.min(depth_minmax)) * mask
intr = o3d.camera.PinholeCameraIntrinsic(depth_img.shape[1], depth_img.shape[0],  # expects (width, height, fx, fy, cx, cy)
self.cam_int[0, 0], self.cam_int[1, 1], self.cam_int[0,2],
self.cam_int[1,2])
depth_im = o3d.geometry.Image(depth_img.astype(np.float32, copy=False))
# rgbd_im = o3d.geometry.RGBDImage.create_from_color_and_depth(color_im, depth_im)
pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_im, intr, self.cam_ext, depth_scale=1.)
pc = np.asarray(pcd.points)
return pc, depth_img
def __repr__(self):
return 'DepthToSingleViewPoints_' + str(self.cam_ext.shape) + str(self.cam_int.shape)

1
metrics/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
StructuralLosses

View file

@ -0,0 +1 @@
*__pycache__*

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 ThibaultGROUEIX
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,101 @@
# Pytorch Chamfer Distance.
Includes a **CUDA** version and a **PYTHON** version with standard PyTorch operations.
NB: In this repo, dist1 and dist2 are squared point-cloud Euclidean distances, so you should adapt thresholds accordingly.
- [x] F-Score
### CUDA VERSION
- [x] JIT compilation
- [x] Supports multi-gpu
- [x] 2D point clouds.
- [x] 3D point clouds.
- [x] 5D point clouds.
- [x] Contiguous() safe.
### Python Version
- [x] Supports any dimension
### Usage
```python
import torch, chamfer3D.dist_chamfer_3D, fscore
chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
points1 = torch.rand(32, 1000, 3).cuda()
points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda()
dist1, dist2, idx1, idx2 = chamLoss(points1, points2)
f_score, precision, recall = fscore.fscore(dist1, dist2)
```
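Since `dist1` and `dist2` are squared distances (see the note above), here is a small sketch of converting them back to Euclidean distances before thresholding; `tol` is an illustrative name, continuing from the usage snippet:
```python
tol = 0.01  # tolerance in Euclidean (unsquared) units
euclid1, euclid2 = dist1.sqrt(), dist2.sqrt()  # undo the squaring
frac_within = (euclid1 < tol).float().mean()  # fraction of points1 within tol of points2
```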
### Add it to your project as a submodule
```shell
git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
```
### Benchmark: [forward + backward] pass
- [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4
- [x] p1 : 32 x 2000 x dim
- [x] p2 : 32 x 1000 x dim

| *Timing (sec * 1000)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | **1.2** | 1.4 |1.8 |
| **Cuda JIT** | 1.3 | **1.4** |**1.5** |
| **Python** | 37 | 37 | 37 |

| *Memory (MB)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | 529 | 529 | 549 |
| **Cuda JIT** | **520** | **529** |**549** |
| **Python** | 2495 | 2495 | 2495 |
### What is the Chamfer distance?
[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning
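In the squared form used by this repo (matching `dist1`/`dist2` above), the symmetric Chamfer distance between point sets S1 and S2 is:
```latex
CD(S_1, S_2) = \frac{1}{|S_1|} \sum_{x \in S_1} \min_{y \in S_2} \lVert x - y \rVert_2^2
             + \frac{1}{|S_2|} \sum_{y \in S_2} \min_{x \in S_1} \lVert x - y \rVert_2^2
```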
### Acknowledgment
Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu).
JIT cool trick from [Christian Diller](https://github.com/chrdiller)
### Troubleshoot
- `Undefined symbol: Zxxxxxxxxxxxxxxxxx `:
--> Fix: Make sure to `import torch` before you `import chamfer`.
--> Use pytorch.version >= 1.1.0
- [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167)
```shell
wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
```
#### TODO:
* Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions

View file

@ -0,0 +1,182 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=512;
__shared__ float buf[batch*2];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*2;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*2+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*2+0];
float y1=xyz[(i*n+j)*2+1];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&3); // round end_k down to a multiple of 4 to match the 4-way unrolled loop
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*2+2]-x1;
float y2=buf[k*2+3]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*2+4]-x1;
float y2=buf[k*2+5]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*2+6]-x1;
float y2=buf[k*2+7]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*2+2]-x1;
float y2=buf[k*2+3]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*2+4]-x1;
float y2=buf[k*2+5]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*2+6]-x1;
float y2=buf[k*2+7]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*2+0];
float y1=xyz1[(i*n+j)*2+1];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*2+0];
float y2=xyz2[(i*m+j2)*2+1];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*2+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*2+1]),g*(y1-y2));
atomicAdd(&(grad_xyz2[(i*m+j2)*2+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*2+1]),-(g*(y1-y2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@ -0,0 +1,73 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_2D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 2D")
from torch.utils.cpp_extension import load
chamfer_2D = load(name="chamfer_2D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
])
print("Loaded JIT 2D CUDA chamfer distance")
else:
import chamfer_2D
print("Loaded compiled 2D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_2DFunction(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, _ = xyz1.size()
_, m, _ = xyz2.size()
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_2D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_2DDist(nn.Module):
def __init__(self):
super(chamfer_2DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_2DFunction.apply(input1, input2)

View file

@ -0,0 +1,16 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_2D',
ext_modules=[
CUDAExtension('chamfer_2D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
],
# nvcc flags must be attached to the extension; setup() ignores extra_cuda_cflags
extra_compile_args={'nvcc': ['--compiler-bindir=/usr/bin/gcc-8']}),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@ -0,0 +1,196 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=512;
__shared__ float buf[batch*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*3+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*3+0];
float y1=xyz[(i*n+j)*3+1];
float z1=xyz[(i*n+j)*3+2];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&3);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*3+3]-x1;
float y2=buf[k*3+4]-y1;
float z2=buf[k*3+5]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*3+6]-x1;
float y2=buf[k*3+7]-y1;
float z2=buf[k*3+8]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*3+9]-x1;
float y2=buf[k*3+10]-y1;
float z2=buf[k*3+11]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*3+3]-x1;
float y2=buf[k*3+4]-y1;
float z2=buf[k*3+5]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*3+6]-x1;
float y2=buf[k*3+7]-y1;
float z2=buf[k*3+8]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*3+9]-x1;
float y2=buf[k*3+10]-y1;
float z2=buf[k*3+11]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*3+0];
float y1=xyz1[(i*n+j)*3+1];
float z1=xyz1[(i*n+j)*3+2];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*3+0];
float y2=xyz2[(i*m+j2)*3+1];
float z2=xyz2[(i*m+j2)*3+2];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@ -0,0 +1,77 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_3D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 3D")
from torch.utils.cpp_extension import load
chamfer_3D = load(name="chamfer_3D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
],
extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'],)
print("Loaded JIT 3D CUDA chamfer distance")
else:
import chamfer_3D
print("Loaded compiled 3D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, _ = xyz1.size()
_, m, _ = xyz2.size()
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_3D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_3DDist(nn.Module):
def __init__(self):
super(chamfer_3DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_3DFunction.apply(input1, input2)

View file

@ -0,0 +1,16 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_3D',
ext_modules=[
CUDAExtension('chamfer_3D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
],
# nvcc flags must be attached to the extension; setup() ignores extra_cuda_cflags
extra_compile_args={'nvcc': ['--compiler-bindir=/usr/bin/gcc-8']}),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@ -0,0 +1,223 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=2048;
__shared__ float buf[batch*5];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*5;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*5+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*5+0];
float y1=xyz[(i*n+j)*5+1];
float r1=xyz[(i*n+j)*5+2];
float g1=xyz[(i*n+j)*5+3];
float b1=xyz[(i*n+j)*5+4];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&3); // round end_k down to a multiple of 4 to match the 4-way unrolled loop
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*5+0]-x1;
float y2=buf[k*5+1]-y1;
float r2=buf[k*5+2]-r1;
float g2=buf[k*5+3]-g1;
float b2=buf[k*5+4]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*5+5]-x1;
float y2=buf[k*5+6]-y1;
float r2=buf[k*5+7]-r1;
float g2=buf[k*5+8]-g1;
float b2=buf[k*5+9]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*5+10]-x1;
float y2=buf[k*5+11]-y1;
float r2=buf[k*5+12]-r1;
float g2=buf[k*5+13]-g1;
float b2=buf[k*5+14]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*5+15]-x1;
float y2=buf[k*5+16]-y1;
float r2=buf[k*5+17]-r1;
float g2=buf[k*5+18]-g1;
float b2=buf[k*5+19]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*5+0]-x1;
float y2=buf[k*5+1]-y1;
float r2=buf[k*5+2]-r1;
float g2=buf[k*5+3]-g1;
float b2=buf[k*5+4]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*5+5]-x1;
float y2=buf[k*5+6]-y1;
float r2=buf[k*5+7]-r1;
float g2=buf[k*5+8]-g1;
float b2=buf[k*5+9]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*5+10]-x1;
float y2=buf[k*5+11]-y1;
float r2=buf[k*5+12]-r1;
float g2=buf[k*5+13]-g1;
float b2=buf[k*5+14]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*5+15]-x1;
float y2=buf[k*5+16]-y1;
float r2=buf[k*5+17]-r1;
float g2=buf[k*5+18]-g1;
float b2=buf[k*5+19]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*5+0]-x1;
float y2=buf[k*5+1]-y1;
float r2=buf[k*5+2]-r1;
float g2=buf[k*5+3]-g1;
float b2=buf[k*5+4]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*5+0];
float y1=xyz1[(i*n+j)*5+1];
float r1=xyz1[(i*n+j)*5+2];
float g1=xyz1[(i*n+j)*5+3];
float b1=xyz1[(i*n+j)*5+4];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*5+0];
float y2=xyz2[(i*m+j2)*5+1];
float r2=xyz2[(i*m+j2)*5+2];
float g2=xyz2[(i*m+j2)*5+3];
float b2=xyz2[(i*m+j2)*5+4];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*5+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+1]),g*(y1-y2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+2]),g*(r1-r2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+3]),g*(g1-g2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+4]),g*(b1-b2));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+1]),-(g*(y1-y2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+2]),-(g*(r1-r2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+3]),-(g*(g1-g2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+4]),-(g*(b1-b2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@ -0,0 +1,75 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_5D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 5D")
from torch.utils.cpp_extension import load
chamfer_5D = load(name="chamfer_5D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
])
print("Loaded JIT 5D CUDA chamfer distance")
else:
import chamfer_5D
print("Loaded compiled 5D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_5DFunction(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, _ = xyz1.size()
_, m, _ = xyz2.size()
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_5D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_5DDist(nn.Module):
def __init__(self):
super(chamfer_5DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_5DFunction.apply(input1, input2)

View file

@ -0,0 +1,16 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_5D',
ext_modules=[
CUDAExtension('chamfer_5D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
],
# nvcc flags must be attached to the extension; setup() ignores extra_cuda_cflags
extra_compile_args={'nvcc': ['--compiler-bindir=/usr/bin/gcc-8']}),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@ -0,0 +1,40 @@
import torch
def pairwise_dist(x, y):
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
rx = xx.diag().unsqueeze(0).expand_as(xx)
ry = yy.diag().unsqueeze(0).expand_as(yy)
P = rx.t() + ry - 2 * zz
return P
def NN_loss(x, y, dim=0):
dist = pairwise_dist(x, y)
values, indices = dist.min(dim=dim)
return values.mean()
def distChamfer(a, b):
"""
:param a: Pointclouds Batch x num_points x dim
:param b: Pointclouds Batch x num_points x dim
:return:
-closest point on b of points from a
-closest point on a of points from b
-idx of closest point on b of points from a
-idx of closest point on a of points from b
Works for pointcloud of any dimension
"""
x, y = a.double(), b.double()
bs, num_points_x, points_dim = x.size()
bs, num_points_y, points_dim = y.size()
xx = torch.pow(x, 2).sum(2)
yy = torch.pow(y, 2).sum(2)
zz = torch.bmm(x, y.transpose(2, 1))
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
P = rx.transpose(2, 1) + ry - 2 * zz
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
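A quick CPU sketch of the pure-PyTorch fallback above (shapes illustrative). Since it materializes the full B x N x M distance matrix, it is only practical for small clouds:

```python
import torch

a = torch.rand(4, 100, 3)  # Batch x num_points x dim
b = torch.rand(4, 200, 3)
d_ab, d_ba, i_ab, i_ba = distChamfer(a, b)
chamfer = d_ab.mean() + d_ba.mean()  # symmetric (squared) Chamfer distance
print(d_ab.shape, d_ba.shape)        # torch.Size([4, 100]) torch.Size([4, 200])
```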

View file

@ -0,0 +1,17 @@
import torch
def fscore(dist1, dist2, threshold=0.001):
"""
Calculates the F-score between two point clouds with the corresponding threshold value.
:param dist1: Batch, N-Points
:param dist2: Batch, N-Points
    :param threshold: float
:return: fscore, precision, recall
"""
    # NB : In this repo, dist1 and dist2 are squared point-cloud Euclidean distances, so you should adapt the threshold accordingly.
precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
fscore[torch.isnan(fscore)] = 0
return fscore, precision_1, precision_2
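Because the Chamfer kernels in this repo return squared distances, the default threshold of 0.001 corresponds to a Euclidean radius of sqrt(0.001) ≈ 0.032. A small sketch (illustrative shapes and values):

```python
import torch

dist1 = torch.rand(4, 100) * 0.01  # squared distances, sample -> reference
dist2 = torch.rand(4, 200) * 0.01  # squared distances, reference -> sample
radius = 0.05                      # desired Euclidean tolerance
f, p1, p2 = fscore(dist1, dist2, threshold=radius ** 2)  # square it for squared distances
```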

View file

@ -0,0 +1,69 @@
import torch, time
import chamfer2D.dist_chamfer_2D
import chamfer3D.dist_chamfer_3D
import chamfer5D.dist_chamfer_5D
import chamfer_python
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
from torch.autograd import Variable
from fscore import fscore
def test_chamfer(distChamfer, dim):
points1 = torch.rand(4, 100, dim).cuda()
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
loss = torch.sum(dist1)
loss.backward()
mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2)
d1 = (dist1 - mydist1) ** 2
d2 = (dist2 - mydist2) ** 2
assert (
torch.mean(d1) + torch.mean(d2) < 0.00000001
), "chamfer cuda and chamfer normal are not giving the same results"
xd1 = idx1 - myidx1
xd2 = idx2 - myidx2
assert (
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
), "chamfer cuda and chamfer normal are not giving the same results"
print(f"fscore :", fscore(dist1, dist2))
print("Unit test passed")
def timings(distChamfer, dim):
p1 = torch.rand(32, 2000, dim).cuda()
p2 = torch.rand(32, 1000, dim).cuda()
print("Timings : Start CUDA version")
start = time.time()
num_it = 100
for i in range(num_it):
points1 = Variable(p1, requires_grad=True)
points2 = Variable(p2)
mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2)
loss = torch.sum(mydist1)
loss.backward()
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
print("Timings : Start Pythonic version")
start = time.time()
for i in range(num_it):
points1 = Variable(p1, requires_grad=True)
points2 = Variable(p2)
mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2)
loss = torch.sum(mydist1)
loss.backward()
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
dims = [2,3,5]
for i,cham in enumerate([cham2D, cham3D, cham5D]):
print(f"testing Chamfer {dims[i]}D")
test_chamfer(cham, dims[i])
timings(cham, dims[i])

5
metrics/PyTorchEMD/.gitignore vendored Normal file
View file

@ -0,0 +1,5 @@
__pycache__
build
dist
emd_ext.egg-info
*.so

View file

@ -0,0 +1,31 @@
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
## Dependency
The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.
## Usage
First compile using

```bash
python setup.py install
```

Then, copy the lib file out to the main directory,

```bash
cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
```

Then, you can use it by simply

```python
from emd import earth_mover_distance
d = earth_mover_distance(p1, p2, transpose=False)  # p1: B x N1 x 3, p2: B x N2 x 3
```

Check `test_emd_loss.py` for an example.
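A minimal end-to-end sketch (a sketch only, assuming the extension is built and a GPU is available; shapes are illustrative):

```python
import torch
from emd import earth_mover_distance

p1 = torch.rand(8, 1024, 3, device='cuda', requires_grad=True)  # B x N1 x 3
p2 = torch.rand(8, 2048, 3, device='cuda')                      # B x N2 x 3
d = earth_mover_distance(p1, p2, transpose=False)  # one approximate EMD per batch element
d.mean().backward()
```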
## Author
The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Jiayuan Gu also provided help.
## License
MIT

View file

@ -0,0 +1,29 @@
#ifndef _EMD
#define _EMD
#include <vector>
#include <torch/extension.h>
//CUDA declarations
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2);
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
}
#endif

View file

@ -0,0 +1,400 @@
/**********************************
* Original Author: Haoqiang Fan
* Modified by: Kaichun Mo
*********************************/
#ifndef _EMD_KERNEL
#define _EMD_KERNEL
#include <cmath>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
#include <THC/THC.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/********************************
* Forward kernel for approxmatch
*********************************/
template<typename scalar_t>
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
scalar_t multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ scalar_t buf[Block*4];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
match[i*n*m+j]=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x)
remainL[j]=multiL;
for (int j=threadIdx.x;j<m;j+=blockDim.x)
remainR[j]=multiR;
__syncthreads();
for (int j=7;j>=-2;j--){
scalar_t level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
scalar_t suml=1e-9f;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
scalar_t x2=xyz2[i*m*3+l0*3+l*3+0];
scalar_t y2=xyz2[i*m*3+l0*3+l*3+1];
scalar_t z2=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+0]=x2;
buf[l*4+1]=y2;
buf[l*4+2]=z2;
buf[l*4+3]=remainR[l0+l];
}
__syncthreads();
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*4+0];
scalar_t y2=buf[l*4+1];
scalar_t z2=buf[l*4+2];
scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
scalar_t w=__expf(d)*buf[l*4+3];
suml+=w;
}
__syncthreads();
}
if (k<n)
ratioL[k]=remainL[k]/suml;
}
__syncthreads();
for (int l0=0;l0<m;l0+=blockDim.x){
int l=l0+threadIdx.x;
scalar_t x2=0,y2=0,z2=0;
if (l<m){
x2=xyz2[i*m*3+l*3+0];
y2=xyz2[i*m*3+l*3+1];
z2=xyz2[i*m*3+l*3+2];
}
scalar_t sumr=0;
for (int k0=0;k0<n;k0+=Block){
int kend=min(n,k0+Block)-k0;
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
buf[k*4+3]=ratioL[k0+k];
}
__syncthreads();
for (int k=0;k<kend;k++){
scalar_t x1=buf[k*4+0];
scalar_t y1=buf[k*4+1];
scalar_t z1=buf[k*4+2];
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
sumr+=w;
}
__syncthreads();
}
if (l<m){
sumr*=remainR[l];
scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
ratioR[l]=consumption*remainR[l];
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
}
}
__syncthreads();
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
scalar_t suml=0;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+3]=ratioR[l0+l];
}
__syncthreads();
scalar_t rl=ratioL[k];
if (k<n){
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*4+0];
scalar_t y2=buf[l*4+1];
scalar_t z2=buf[l*4+2];
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
match[i*n*m+(l0+l)*n+k]+=w;
suml+=w;
}
}
__syncthreads();
}
if (k<n)
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
}
__syncthreads();
}
}
}
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
//}
/* ApproxMatch forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
Output:
match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto match = at::zeros({b, m, n}, xyz1.type());
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return match;
}
/********************************
* Forward kernel for matchcost
*********************************/
template<typename scalar_t>
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
__shared__ scalar_t allsum[512];
const int Block=1024;
__shared__ scalar_t buf[Block*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
scalar_t subsum=0;
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
buf[l]=xyz2[i*m*3+l0*3+l];
__syncthreads();
if (k<n){
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*3+0];
scalar_t y2=buf[l*3+1];
scalar_t z2=buf[l*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
subsum+=d*match[i*n*m+(l0+l)*n+k];
}
}
__syncthreads();
}
}
allsum[threadIdx.x]=subsum;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
}
}
if (threadIdx.x==0)
out[i]=allsum[0];
__syncthreads();
}
}
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
//}
/* MatchCost forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
cost: (B)
*/
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto cost = at::zeros({b}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return cost;
}
/********************************
* matchcostgrad2 kernel
*********************************/
template<typename scalar_t>
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
__shared__ scalar_t sum_grad[256*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
int kbeg=m*blockIdx.y/gridDim.y;
int kend=m*(blockIdx.y+1)/gridDim.y;
for (int k=kbeg;k<kend;k++){
scalar_t x2=xyz2[(i*m+k)*3+0];
scalar_t y2=xyz2[(i*m+k)*3+1];
scalar_t z2=xyz2[(i*m+k)*3+2];
scalar_t subsumx=0,subsumy=0,subsumz=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x){
scalar_t x1=x2-xyz1[(i*n+j)*3+0];
scalar_t y1=y2-xyz1[(i*n+j)*3+1];
scalar_t z1=z2-xyz1[(i*n+j)*3+2];
scalar_t d=match[i*n*m+k*n+j]*2;
subsumx+=x1*d;
subsumy+=y1*d;
subsumz+=z1*d;
}
sum_grad[threadIdx.x*3+0]=subsumx;
sum_grad[threadIdx.x*3+1]=subsumy;
sum_grad[threadIdx.x*3+2]=subsumz;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
int j1=threadIdx.x;
int j2=threadIdx.x+j;
if ((j1&j)==0 && j2<blockDim.x){
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
}
}
if (threadIdx.x==0){
grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
}
__syncthreads();
}
}
}
/********************************
* matchcostgrad1 kernel
*********************************/
template<typename scalar_t>
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int l=threadIdx.x;l<n;l+=blockDim.x){
scalar_t x1=xyz1[i*n*3+l*3+0];
scalar_t y1=xyz1[i*n*3+l*3+1];
scalar_t z1=xyz1[i*n*3+l*3+2];
scalar_t dx=0,dy=0,dz=0;
for (int k=0;k<m;k++){
scalar_t x2=xyz2[i*m*3+k*3+0];
scalar_t y2=xyz2[i*m*3+k*3+1];
scalar_t z2=xyz2[i*m*3+k*3+2];
scalar_t d=match[i*n*m+k*n+l]*2;
dx+=(x1-x2)*d;
dy+=(y1-y2)*d;
dz+=(z1-z2)*d;
}
grad1[i*n*3+l*3+0]=dx*grad_cost[i];
grad1[i*n*3+l*3+1]=dy*grad_cost[i];
grad1[i*n*3+l*3+2]=dz*grad_cost[i];
}
}
}
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/* MatchCost backward interface
Input:
grad_cost: (B) # gradients on cost
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
grad1: (B, N1, 3)
grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return std::vector<at::Tensor>({grad1, grad2});
}
#endif

47
metrics/PyTorchEMD/emd.py Normal file
View file

@ -0,0 +1,47 @@
import torch
import emd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost):
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True):
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (b, 3, n1)
        xyz2 (torch.Tensor): (b, 3, n2)
        transpose (bool): whether to transpose inputs, as they might be in BCN format.
            The extension only supports BNC format.
Returns:
cost (torch.Tensor): (b)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
cost = cost / xyz1.shape[1]
return cost

View file

@ -0,0 +1,27 @@
"""Setup extension
Notes:
If extra_compile_args is provided, you need to provide different instances for different extensions.
Refer to https://github.com/pytorch/pytorch/issues/20169
"""
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
ext_modules=[
CUDAExtension(
name='emd_cuda',
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@ -0,0 +1,44 @@
import torch
import numpy as np
import time
from emd import earth_mover_distance
# gt
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist.backward()
print(p1.grad)
print(p2.grad)
# emd
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
d = earth_mover_distance(p1, p2, transpose=False)
print(d)
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)

0
metrics/__init__.py Normal file
View file

View file

@ -0,0 +1,322 @@
import torch
import numpy as np
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
from metrics.ChamferDistancePytorch.fscore import fscore
from tqdm import tqdm
cham3D = chamfer_3DDist()
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
x, y = a, b
bs, num_points, points_dim = x.size()
xx = torch.bmm(x, x.transpose(2, 1))
yy = torch.bmm(y, y.transpose(2, 1))
zz = torch.bmm(x, y.transpose(2, 1))
diag_ind = torch.arange(0, num_points).to(a).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2, 1) + ry - 2 * zz)
return P.min(1)[0], P.min(2)[0]
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
cd_lst = []
emd_lst = []
fs_lst = []
iterator = range(0, N_sample, batch_size)
for b_start in iterator:
b_end = min(N_sample, b_start + batch_size)
sample_batch = sample_pcs[b_start:b_end]
ref_batch = ref_pcs[b_start:b_end]
dl, dr, _, _ = cham3D(sample_batch.cuda(), ref_batch.cuda())
fs = fscore(dl, dr)[0].cpu()
fs_lst.append(fs)
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
emd_batch = EMD(sample_batch.cuda(), ref_batch.cuda(), transpose=False)
emd_lst.append(emd_batch)
if reduced:
cd = torch.cat(cd_lst).mean()
emd = torch.cat(emd_lst).mean()
else:
cd = torch.cat(cd_lst)
emd = torch.cat(emd_lst)
fs_lst = torch.cat(fs_lst).mean()
results = {
'MMD-CD': cd,
'MMD-EMD': emd,
'fscore': fs_lst
}
return results
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
all_cd = []
all_emd = []
iterator = range(N_sample)
for sample_b_start in tqdm(iterator):
sample_batch = sample_pcs[sample_b_start]
cd_lst = []
emd_lst = []
for ref_b_start in range(0, N_ref, batch_size):
ref_b_end = min(N_ref, ref_b_start + batch_size)
ref_batch = ref_pcs[ref_b_start:ref_b_end]
batch_size_ref = ref_batch.size(0)
sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
sample_batch_exp = sample_batch_exp.contiguous()
dl, dr, _, _ = cham3D(sample_batch_exp.cuda(), ref_batch.cuda())
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1).detach().cpu())
emd_batch = EMD(sample_batch_exp.cuda(), ref_batch.cuda(), transpose=False)
emd_lst.append(emd_batch.view(1, -1).detach().cpu())
cd_lst = torch.cat(cd_lst, dim=1)
emd_lst = torch.cat(emd_lst, dim=1)
all_cd.append(cd_lst)
all_emd.append(emd_lst)
all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref
all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref
return all_cd, all_emd
# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
def knn(Mxx, Mxy, Myy, k, sqrt=False):
n0 = Mxx.size(0)
n1 = Myy.size(0)
label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
if sqrt:
M = M.abs().sqrt()
INFINITY = float('inf')
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
count = torch.zeros(n0 + n1).to(Mxx)
for i in range(0, k):
count = count + label.index_select(0, idx[i])
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
s = {
'tp': (pred * label).sum(),
'fp': (pred * (1 - label)).sum(),
'fn': ((1 - pred) * label).sum(),
'tn': ((1 - pred) * (1 - label)).sum(),
}
s.update({
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
'acc': torch.eq(label, pred).float().mean(),
})
return s
def lgan_mmd_cov(all_dist):
N_sample, N_ref = all_dist.size(0), all_dist.size(1)
min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
min_val, _ = torch.min(all_dist, dim=0)
mmd = min_val.mean()
mmd_smp = min_val_fromsmp.mean()
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
cov = torch.tensor(cov).to(all_dist)
return {
'lgan_mmd': mmd,
'lgan_cov': cov,
'lgan_mmd_smp': mmd_smp,
}
def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
results = {}
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
res_cd = lgan_mmd_cov(M_rs_cd.t())
results.update({
"%s-CD" % k: v for k, v in res_cd.items()
})
res_emd = lgan_mmd_cov(M_rs_emd.t())
results.update({
"%s-EMD" % k: v for k, v in res_emd.items()
})
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
# 1-NN results
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
results.update({
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
})
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
results.update({
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
})
return results
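A hedged sketch of how these metrics might be invoked (shapes illustrative; this path requires the compiled Chamfer and EMD CUDA extensions and a GPU):

```python
import torch

sample_pcs = torch.rand(16, 2048, 3)  # generated clouds
ref_pcs = torch.rand(16, 2048, 3)     # reference clouds
results = compute_all_metrics(sample_pcs, ref_pcs, batch_size=8)
for k, v in results.items():
    print(k, float(v))
```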
#######################################################
# JSD : from https://github.com/optas/latent_3d_points
#######################################################
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
"""Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
that is placed in the unit-cube.
    If clip_sphere is True, it drops the "corner" cells that lie outside the unit-sphere.
"""
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
spacing = 1.0 / float(resolution - 1)
for i in range(resolution):
for j in range(resolution):
for k in range(resolution):
grid[i, j, k, 0] = i * spacing - 0.5
grid[i, j, k, 1] = j * spacing - 0.5
grid[i, j, k, 2] = k * spacing - 0.5
if clip_sphere:
grid = grid.reshape(-1, 3)
grid = grid[norm(grid, axis=1) <= 0.5]
return grid, spacing
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
"""Computes the JSD between two sets of point-clouds, as introduced in the paper
```Learning Representations And Generative Models For 3D Point Clouds```.
Args:
        sample_pcs: (np.ndarray S1 x R1 x 3) S1 point clouds, each of R1 points.
        ref_pcs: (np.ndarray S2 x R2 x 3) S2 point clouds, each of R2 points.
resolution: (int) grid-resolution. Affects granularity of measurements.
"""
in_unit_sphere = True
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
"""Given a collection of point-clouds, estimate the entropy of the random variables
corresponding to occupancy-grid activation patterns.
Inputs:
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
grid_resolution (int) size of occupancy grid that will be used.
"""
epsilon = 10e-4
bound = 0.5 + epsilon
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit cube.')
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit sphere.')
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
grid_coordinates = grid_coordinates.reshape(-1, 3)
grid_counters = np.zeros(len(grid_coordinates))
grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
for pc in pclouds:
_, indices = nn.kneighbors(pc)
indices = np.squeeze(indices)
for i in indices:
grid_counters[i] += 1
indices = np.unique(indices)
for i in indices:
grid_bernoulli_rvars[i] += 1
acc_entropy = 0.0
n = float(len(pclouds))
for g in grid_bernoulli_rvars:
if g > 0:
p = float(g) / n
acc_entropy += entropy([p, 1.0 - p])
return acc_entropy / len(grid_counters), grid_counters
def jensen_shannon_divergence(P, Q):
if np.any(P < 0) or np.any(Q < 0):
raise ValueError('Negative values.')
if len(P) != len(Q):
raise ValueError('Non equal size.')
P_ = P / np.sum(P) # Ensure probabilities.
Q_ = Q / np.sum(Q)
e1 = entropy(P_, base=2)
e2 = entropy(Q_, base=2)
e_sum = entropy((P_ + Q_) / 2.0, base=2)
res = e_sum - ((e1 + e2) / 2.0)
res2 = _jsdiv(P_, Q_)
if not np.allclose(res, res2, atol=10e-5, rtol=0):
warnings.warn('Numerical values of two JSD methods don\'t agree.')
return res
def _jsdiv(P, Q):
"""another way of computing JSD"""
def _kldiv(A, B):
a = A.copy()
b = B.copy()
idx = np.logical_and(a > 0, b > 0)
a = a[idx]
b = b[idx]
return np.sum([v for v in a * np.log2(a / b)])
P_ = P / np.sum(P)
Q_ = Q / np.sum(Q)
M = 0.5 * (P_ + Q_)
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
if __name__ == "__main__":
B, N = 2, 10
x = torch.rand(B, N, 3)
y = torch.rand(B, N, 3)
min_l, min_r = distChamfer(x.cuda(), y.cuda())
print(min_l.shape)
print(min_r.shape)
l_dist = min_l.mean().cpu().detach().item()
r_dist = min_r.mean().cpu().detach().item()
print(l_dist, r_dist)
emd_batch = EMD(x.cuda(), y.cuda(), False)
print(emd_batch.shape)
print(emd_batch.mean().detach().item())
jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy())
print(jsd)

0
model/__init__.py Normal file
View file

253
model/pvcnn_completion.py Normal file
View file

@ -0,0 +1,253 @@
import functools
import torch.nn as nn
import torch
import numpy as np
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
r = width_multiplier
if dim == 1:
block = _linear_gn_relu
else:
block = SharedMLP
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        return nn.Sequential(), in_channels  # two values, matching the normal return path below
layers = []
for oc in out_channels[:-1]:
if oc < 1:
layers.append(nn.Dropout(oc))
else:
oc = int(r * oc)
layers.append(block(in_channels, oc))
in_channels = oc
if dim == 1:
if classifier:
layers.append(nn.Linear(in_channels, out_channels[-1]))
else:
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
else:
if classifier:
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
else:
layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0
c = 0
for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = k % 2 == 0 and k > 0 and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
with_se=with_se, normalize=normalize, eps=eps)
if c == 0:
layers.append(block(in_channels, out_channels))
else:
layers.append(block(in_channels+embed_dim, out_channels))
in_channels = out_channels
concat_channels += out_channels
c += 1
return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
dropout=0.1, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels = [], []
c = 0
for conv_configs, sa_configs in sa_blocks:
k = 0
sa_in_channels.append(in_channels)
sa_blocks = []
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se and not attention, with_se_relu=True,
normalize=normalize, eps=eps)
if c == 0:
sa_blocks.append(block(in_channels, out_channels))
elif k ==0:
sa_blocks.append(block(in_channels+embed_dim, out_channels))
in_channels = out_channels
k += 1
extra_feature_channels = in_channels
num_centers, radius, num_neighbors, out_channels = sa_configs
_out_channels = []
for oc in out_channels:
if isinstance(oc, (list, tuple)):
_out_channels.append([int(r * _oc) for _oc in oc])
else:
_out_channels.append(int(r * oc))
out_channels = _out_channels
if num_centers is None:
block = PointNetAModule
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
include_coordinates=True))
c += 1
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
if len(sa_blocks) == 1:
sa_layers.append(sa_blocks[0])
else:
sa_layers.append(nn.Sequential(*sa_blocks))
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False,
dropout=0.1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
c = 0
for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
)
in_channels = out_channels[-1]
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = c % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps)
fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels
if len(fp_blocks) == 1:
fp_layers.append(fp_blocks[0])
else:
fp_layers.append(nn.Sequential(*fp_blocks))
c += 1
return fp_layers, in_channels
class PVCNN2Base(nn.Module):
def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1,
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
super().__init__()
assert extra_feature_channels >= 0
self.embed_dim = embed_dim
self.sv_points = sv_points
self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.sa_layers = nn.ModuleList(sa_layers)
self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1)
# only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points,
with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes],
classifier=True, dim=2, width_multiplier=width_multiplier)
self.classifier = nn.Sequential(*layers)
self.embedf = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.LeakyReLU(0.1, inplace=True),
nn.Linear(embed_dim, embed_dim),
)
def get_timestep_embedding(self, timesteps, device):
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
half_dim = self.embed_dim // 2
emb = np.log(10000) / (half_dim - 1)
emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
# emb = tf.range(num_embeddings, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :]
emb = timesteps[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if self.embed_dim % 2 == 1: # zero pad
# emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1)
emb = nn.functional.pad(emb, (0, 1), "constant", 0)
assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim])
return emb
def forward(self, inputs, t):
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
# inputs : [B, in_channels + S, N]
coords, features = inputs[:, :3, :].contiguous(), inputs
coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
if i == 0:
features, coords, temb = sa_blocks ((features, coords, temb))
else:
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None:
features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers):
jump_coords = coords_list[-1 - fp_idx]
            jump_feats = in_features_list[-1 - fp_idx]
            # if fp_idx == len(self.fp_layers) - 1:
            #     jump_coords = jump_coords[:, :, self.sv_points:]
            #     jump_feats = jump_feats[:, :, self.sv_points:]
            features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features, temb], dim=1), jump_feats, temb))
return self.classifier(features)
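`get_timestep_embedding` above is the standard sinusoidal timestep encoding used in DDPM-style models. A self-contained sketch of the same computation (assuming an even `embed_dim`, so the zero-padding branch is not needed):

```python
import numpy as np
import torch

def timestep_embedding(timesteps, embed_dim):
    # timesteps: 1-D tensor of diffusion steps, one per batch element
    half_dim = embed_dim // 2
    freqs = torch.from_numpy(
        np.exp(np.arange(0, half_dim) * -(np.log(10000) / (half_dim - 1)))
    ).float()
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # (B, embed_dim)

emb = timestep_embedding(torch.arange(4), 64)
print(emb.shape)  # torch.Size([4, 64])
```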

247
model/pvcnn_generation.py Normal file
View file

@ -0,0 +1,247 @@
import functools
import torch.nn as nn
import torch
import numpy as np
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
r = width_multiplier
if dim == 1:
block = _linear_gn_relu
else:
block = SharedMLP
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        return nn.Sequential(), in_channels  # two values, matching the normal return path below
layers = []
for oc in out_channels[:-1]:
if oc < 1:
layers.append(nn.Dropout(oc))
else:
oc = int(r * oc)
layers.append(block(in_channels, oc))
in_channels = oc
if dim == 1:
if classifier:
layers.append(nn.Linear(in_channels, out_channels[-1]))
else:
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
else:
if classifier:
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
else:
layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0
c = 0
for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = k % 2 == 0 and k > 0 and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
with_se=with_se, normalize=normalize, eps=eps)
if c == 0:
layers.append(block(in_channels, out_channels))
else:
layers.append(block(in_channels+embed_dim, out_channels))
in_channels = out_channels
concat_channels += out_channels
c += 1
return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
dropout=0.1, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels = [], []
c = 0
for conv_configs, sa_configs in sa_blocks:
k = 0
sa_in_channels.append(in_channels)
sa_blocks = []
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, with_se_relu=True,
normalize=normalize, eps=eps)
if c == 0:
sa_blocks.append(block(in_channels, out_channels))
elif k ==0:
sa_blocks.append(block(in_channels+embed_dim, out_channels))
in_channels = out_channels
k += 1
extra_feature_channels = in_channels
num_centers, radius, num_neighbors, out_channels = sa_configs
_out_channels = []
for oc in out_channels:
if isinstance(oc, (list, tuple)):
_out_channels.append([int(r * _oc) for _oc in oc])
else:
_out_channels.append(int(r * oc))
out_channels = _out_channels
if num_centers is None:
block = PointNetAModule
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
include_coordinates=True))
c += 1
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
if len(sa_blocks) == 1:
sa_layers.append(sa_blocks[0])
else:
sa_layers.append(nn.Sequential(*sa_blocks))
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
dropout=0.1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
c = 0
for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
)
in_channels = out_channels[-1]
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, with_se_relu=True,
normalize=normalize, eps=eps)
fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels
if len(fp_blocks) == 1:
fp_layers.append(fp_blocks[0])
else:
fp_layers.append(nn.Sequential(*fp_blocks))
c += 1
return fp_layers, in_channels
class PVCNN2Base(nn.Module):
def __init__(self, num_classes, embed_dim, use_att, dropout=0.1,
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
super().__init__()
assert extra_feature_channels >= 0
self.embed_dim = embed_dim
self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.sa_layers = nn.ModuleList(sa_layers)
self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1)
# only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes], # was 0.5
classifier=True, dim=2, width_multiplier=width_multiplier)
self.classifier = nn.Sequential(*layers)
self.embedf = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.LeakyReLU(0.1, inplace=True),
nn.Linear(embed_dim, embed_dim),
)
def get_timestep_embedding(self, timesteps, device):
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
half_dim = self.embed_dim // 2
emb = np.log(10000) / (half_dim - 1)
emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
# emb = tf.range(num_embeddings, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :]
emb = timesteps[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if self.embed_dim % 2 == 1: # zero pad
# emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1)
emb = nn.functional.pad(emb, (0, 1), "constant", 0)
assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim])
return emb
def forward(self, inputs, t):
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
# inputs : [B, in_channels + S, N]
coords, features = inputs[:, :3, :].contiguous(), inputs
coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
if i == 0:
features, coords, temb = sa_blocks ((features, coords, temb))
else:
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None:
features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers):
features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb))
return self.classifier(features)
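In `create_mlp_components`, entries of `out_channels` below 1 are read as dropout probabilities rather than layer widths, so `out_channels=[128, dropout, num_classes]` builds an MLP block to 128 channels, then Dropout, then the classifier head. A self-contained sketch of the same convention (plain `nn` layers stand in for `SharedMLP`, and the width multiplier is omitted):

```python
import torch.nn as nn

def mlp_from_spec(in_channels, spec):
    # spec entries < 1 are dropout probabilities; otherwise channel widths
    layers = []
    for oc in spec[:-1]:
        if oc < 1:
            layers.append(nn.Dropout(oc))
        else:
            layers.append(nn.Sequential(nn.Conv1d(in_channels, oc, 1), nn.ReLU()))
            in_channels = oc
    layers.append(nn.Conv1d(in_channels, spec[-1], 1))  # classifier head
    return nn.Sequential(*layers)

head = mlp_from_spec(256, [128, 0.1, 3])  # mirrors out_channels=[128, dropout, num_classes]
```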

8
modules/__init__.py Normal file
View file

@ -0,0 +1,8 @@
from modules.ball_query import BallQuery
from modules.frustum import FrustumPointNetLoss
from modules.loss import KLLoss
from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
from modules.pvconv import PVConv, Attention, Swish, PVConvReLU
from modules.se import SE3d
from modules.shared_mlp import SharedMLP
from modules.voxelization import Voxelization

34
modules/ball_query.py Normal file
View file

@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import modules.functional as F
__all__ = ['BallQuery']
class BallQuery(nn.Module):
def __init__(self, radius, num_neighbors, include_coordinates=True):
super().__init__()
self.radius = radius
self.num_neighbors = num_neighbors
self.include_coordinates = include_coordinates
def forward(self, points_coords, centers_coords, temb, points_features=None):
points_coords = points_coords.contiguous()
centers_coords = centers_coords.contiguous()
neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
if points_features is None:
assert self.include_coordinates, 'No Features For Grouping'
neighbor_features = neighbor_coordinates
else:
neighbor_features = F.grouping(points_features, neighbor_indices)
if self.include_coordinates:
neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
return neighbor_features, F.grouping(temb, neighbor_indices)
def extra_repr(self):
return 'radius={}, num_neighbors={}{}'.format(
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')

138
modules/frustum.py Normal file
View file

@ -0,0 +1,138 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import modules.functional as PF
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
class FrustumPointNetLoss(nn.Module):
def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
super().__init__()
self.box_loss_weight = box_loss_weight
self.corners_loss_weight = corners_loss_weight
self.heading_residual_loss_weight = heading_residual_loss_weight
self.size_residual_loss_weight = size_residual_loss_weight
self.num_heading_angle_bins = num_heading_angle_bins
self.num_size_templates = num_size_templates
self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
self.register_buffer(
'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
)
def forward(self, inputs, targets):
mask_logits = inputs['mask_logits'] # (B, 2, N)
center_reg = inputs['center_reg'] # (B, 3)
center = inputs['center'] # (B, 3)
heading_scores = inputs['heading_scores'] # (B, NH)
heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
heading_residuals = inputs['heading_residuals'] # (B, NH)
size_scores = inputs['size_scores'] # (B, NS)
size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
size_residuals = inputs['size_residuals'] # (B, NS, 3)
mask_logits_target = targets['mask_logits'] # (B, N)
center_target = targets['center'] # (B, 3)
heading_bin_id_target = targets['heading_bin_id'] # (B, )
heading_residual_target = targets['heading_residual'] # (B, )
size_template_id_target = targets['size_template_id'] # (B, )
size_residual_target = targets['size_residual'] # (B, 3)
batch_size = center.size(0)
batch_id = torch.arange(batch_size, device=center.device)
# Basic Classification and Regression losses
mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
size_loss = F.cross_entropy(size_scores, size_template_id_target)
center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)
# Refinement losses for size/heading
heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, )
heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
heading_residual_normalized_loss = PF.huber_loss(
heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
)
size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3)
size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
size_residual_normalized_loss = PF.huber_loss(
torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
)
# Bounding box losses
heading = (heading_residuals[batch_id, heading_bin_id_target]
+ self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
        # Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets)
size = (size_residuals[batch_id, size_template_id_target]
+ self.size_templates[size_template_id_target]) # (B, 3)
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
sizes=size_target, with_flip=True) # (B, 3, 8)
corners_loss = PF.huber_loss(torch.min(
torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
), delta=1.0)
# Summing up
loss = mask_loss + self.box_loss_weight * (
center_loss + center_reg_loss + heading_loss + size_loss
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
+ self.size_residual_loss_weight * size_residual_normalized_loss
+ self.corners_loss_weight * corners_loss
)
return loss
def get_box_corners_3d(centers, headings, sizes, with_flip=False):
"""
:param centers: coords of box centers, FloatTensor[N, 3]
:param headings: heading angles, FloatTensor[N, ]
:param sizes: box sizes, FloatTensor[N, 3]
:param with_flip: bool, whether to return flipped box (headings + np.pi)
:return:
coords of box corners, FloatTensor[N, 3, 8]
    NOTE: corner points are in counter clockwise order, e.g.,
          2--1
         /  /|
        3--0 5
        |  |/
        7--4
"""
l = sizes[:, 0] # (N,)
w = sizes[:, 1] # (N,)
h = sizes[:, 2] # (N,)
x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8)
y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8)
z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8)
c = torch.cos(headings) # (N,)
s = torch.sin(headings) # (N,)
o = torch.ones_like(headings) # (N,)
z = torch.zeros_like(headings) # (N,)
centers = centers.unsqueeze(-1) # (B, 3, 1)
corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3)
if with_flip:
R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
else:
return torch.matmul(R, corners) + centers
# centers = centers.unsqueeze(1) # (B, 1, 3)
# corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3)
# RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
# if with_flip:
# RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3)
# return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3)
# else:
# return torch.matmul(corners, RT) + centers # (N, 8, 3)
# corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
# R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
# corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8)
# corners = corners.transpose(1, 2) # (N, 8, 3)
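A quick CPU sanity check of `get_box_corners_3d` (illustrative values): with zero heading the rotation matrix is the identity, so a unit cube centered at the origin yields corners at ±0.5 on each axis.

```python
import torch

centers = torch.zeros(1, 3)
headings = torch.zeros(1)
sizes = torch.ones(1, 3)  # l = w = h = 1
corners = get_box_corners_3d(centers, headings, sizes)  # (1, 3, 8)
print(corners.abs().max())  # tensor(0.5000)
```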

View file

@ -0,0 +1,7 @@
from modules.functional.ball_query import ball_query
from modules.functional.devoxelization import trilinear_devoxelize
from modules.functional.grouping import grouping
from modules.functional.interpolatation import nearest_neighbor_interpolate
from modules.functional.loss import kl_loss, huber_loss
from modules.functional.sampling import gather, furthest_point_sample, logits_mask
from modules.functional.voxelization import avg_voxelize

View file

@ -0,0 +1,26 @@
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
_backend = load(name='_pvcnn_backend',
extra_cflags=['-O3', '-std=c++17'],
extra_cuda_cflags=['--compiler-bindir=/usr/bin/gcc-8'],
sources=[os.path.join(_src_path,'src', f) for f in [
'ball_query/ball_query.cpp',
'ball_query/ball_query.cu',
'grouping/grouping.cpp',
'grouping/grouping.cu',
'interpolate/neighbor_interpolate.cpp',
'interpolate/neighbor_interpolate.cu',
'interpolate/trilinear_devox.cpp',
'interpolate/trilinear_devox.cu',
'sampling/sampling.cpp',
'sampling/sampling.cu',
'voxelization/vox.cpp',
'voxelization/vox.cu',
'bindings.cpp',
]]
)
__all__ = ['_backend']

View file

@ -0,0 +1,19 @@
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['ball_query']
def ball_query(centers_coords, points_coords, radius, num_neighbors):
"""
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param radius: float, radius of ball query
:param num_neighbors: int, maximum number of neighbors
:return:
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
"""
centers_coords = centers_coords.contiguous()
points_coords = points_coords.contiguous()
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)

View file

@ -0,0 +1,42 @@
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['trilinear_devoxelize']
class TrilinearDevoxelization(Function):
@staticmethod
def forward(ctx, features, coords, resolution, is_training=True):
"""
:param ctx:
:param coords: the coordinates of points, FloatTensor[B, 3, N]
:param features: FloatTensor[B, C, R, R, R]
:param resolution: int, the voxel resolution
:param is_training: bool, training mode
:return:
FloatTensor[B, C, N]
"""
B, C = features.shape[:2]
features = features.contiguous().view(B, C, -1)
coords = coords.contiguous()
outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features)
if is_training:
ctx.save_for_backward(inds, wgts)
ctx.r = resolution
return outs
@staticmethod
def backward(ctx, grad_output):
"""
:param ctx:
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
:return:
gradient of inputs, FloatTensor[B, C, R, R, R]
"""
inds, wgts = ctx.saved_tensors
grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, ctx.r)
return grad_inputs.view(grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r), None, None, None
trilinear_devoxelize = TrilinearDevoxelization.apply

View file

@ -0,0 +1,31 @@
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['grouping']
class Grouping(Function):
@staticmethod
def forward(ctx, features, indices):
"""
:param ctx:
:param features: features of points, FloatTensor[B, C, N]
:param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors
:return:
grouped_features: grouped features, FloatTensor[B, C, M, U]
"""
features = features.contiguous()
indices = indices.contiguous()
ctx.save_for_backward(indices)
ctx.num_points = features.size(-1)
return _backend.grouping_forward(features, indices)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None
grouping = Grouping.apply
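`grouping` is typically paired with `ball_query`: the returned indices select, for each center, the features of its neighbors. An illustrative sketch:

```python
import torch
from modules.functional import ball_query, grouping

points = torch.rand(2, 3, 1024, device='cuda')   # coordinates, [B, 3, N]
feats = torch.rand(2, 64, 1024, device='cuda')   # features,    [B, C, N]
centers = points[:, :, :128].contiguous()        # M=128 centers
idx = ball_query(centers, points, 0.2, 16)       # IntTensor[2, 128, 16]
grouped = grouping(feats, idx)                   # FloatTensor[2, 64, 128, 16]
```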

38
modules/functional/interpolatation.py Normal file
View file

@ -0,0 +1,38 @@
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['nearest_neighbor_interpolate']
class NeighborInterpolation(Function):
@staticmethod
def forward(ctx, points_coords, centers_coords, centers_features):
"""
:param ctx:
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param centers_features: features of centers, FloatTensor[B, C, M]
:return:
points_features: features of points, FloatTensor[B, C, N]
"""
centers_coords = centers_coords.contiguous()
points_coords = points_coords.contiguous()
centers_features = centers_features.contiguous()
points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
points_coords, centers_coords, centers_features
)
ctx.save_for_backward(indices, weights)
ctx.num_centers = centers_coords.size(-1)
return points_features
@staticmethod
def backward(ctx, grad_output):
indices, weights = ctx.saved_tensors
grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
grad_output.contiguous(), indices, weights, ctx.num_centers
)
return None, None, grad_centers_features
nearest_neighbor_interpolate = NeighborInterpolation.apply
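An illustrative sketch: propagate coarse center features back onto the dense point set, as in PointNet++-style feature propagation:

```python
import torch
from modules.functional import nearest_neighbor_interpolate

points = torch.rand(2, 3, 1024, device='cuda')        # dense points,   [B, 3, N]
centers = torch.rand(2, 3, 128, device='cuda')        # coarse centers, [B, 3, M]
center_feats = torch.rand(2, 64, 128, device='cuda')  # [B, C, M]
point_feats = nearest_neighbor_interpolate(points, centers, center_feats)
print(point_feats.shape)  # torch.Size([2, 64, 1024])
```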

17
modules/functional/loss.py Normal file
View file

@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F
__all__ = ['kl_loss', 'huber_loss']
def kl_loss(x, y):
x = F.softmax(x.detach(), dim=1)
y = F.log_softmax(y, dim=1)
return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1))
def huber_loss(error, delta):
abs_error = torch.abs(error)
quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
return torch.mean(losses)
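In effect, `kl_loss` treats the rows of `x` and `y` as logits and computes the KL divergence from the detached distribution `p = softmax(x)` to `q = softmax(y)` (so gradients flow only through `y`), while `huber_loss` is the standard Huber penalty with threshold `delta`:

```latex
\mathrm{kl\_loss}(x, y) = \operatorname{mean}_b \sum_i p_{b,i}\,(\log p_{b,i} - \log q_{b,i}),
\qquad p = \mathrm{softmax}(x),\; q = \mathrm{softmax}(y)

\mathrm{huber\_loss}(e, \delta) = \operatorname{mean}\!\big[\tfrac{1}{2}\,q^2 + \delta\,(|e| - q)\big],
\qquad q = \min(|e|, \delta)
```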

84
modules/functional/sampling.py Normal file
View file

@ -0,0 +1,84 @@
import numpy as np
import torch
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
class Gather(Function):
@staticmethod
def forward(ctx, features, indices):
"""
Gather
:param ctx:
:param features: features of points, FloatTensor[B, C, N]
:param indices: centers' indices in points, IntTensor[b, m]
:return:
centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
"""
features = features.contiguous()
indices = indices.int().contiguous()
ctx.save_for_backward(indices)
ctx.num_points = features.size(-1)
return _backend.gather_features_forward(features, indices)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None
gather = Gather.apply
def furthest_point_sample(coords, num_samples):
"""
    Uses iterative furthest point sampling to select num_samples (M) points such that
    each newly selected point has the largest minimum distance to the already-sampled set
:param coords: coordinates of points, FloatTensor[B, 3, N]
:param num_samples: int, M
:return:
centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
"""
coords = coords.contiguous()
indices = _backend.furthest_point_sampling(coords, num_samples)
return gather(coords, indices)
def logits_mask(coords, logits, num_points_per_object):
"""
Use logits to sample points
:param coords: coords of points, FloatTensor[B, 3, N]
:param logits: binary classification logits, FloatTensor[B, 2, N]
:param num_points_per_object: M, #points per object after masking, int
:return:
selected_coords: FloatTensor[B, 3, M]
masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
mask: mask to select points, BoolTensor[B, N]
"""
batch_size, _, num_points = coords.shape
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
torch.ones_like(num_candidates)).float() # [B, C]
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
for i in range(batch_size):
current_mask = mask[i] # [N]
current_candidates = current_mask.nonzero().view(-1)
current_num_candidates = current_candidates.numel()
if current_num_candidates >= num_points_per_object:
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
selected_indices[i] = current_candidates[choices]
elif current_num_candidates > 0:
choices = np.concatenate([
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
])
np.random.shuffle(choices)
selected_indices[i] = current_candidates[choices]
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
return selected_coords, masked_coords_mean, mask
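An illustrative sketch; note that `furthest_point_sample` returns the sampled coordinates themselves (it gathers internally), not the indices:

```python
import torch
from modules.functional import furthest_point_sample

coords = torch.rand(2, 3, 4096, device='cuda')   # [B, 3, N]
centers = furthest_point_sample(coords, 256)     # FloatTensor[2, 3, 256]
print(centers.shape)  # torch.Size([2, 3, 256])
```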

30
modules/functional/src/ball_query/ball_query.cpp Normal file
View file

@ -0,0 +1,30 @@
#include "ball_query.hpp"
#include "ball_query.cuh"
#include "../utils.hpp"
at::Tensor ball_query_forward(at::Tensor centers_coords,
at::Tensor points_coords, const float radius,
const int num_neighbors) {
CHECK_CUDA(centers_coords);
CHECK_CUDA(points_coords);
CHECK_CONTIGUOUS(centers_coords);
CHECK_CONTIGUOUS(points_coords);
CHECK_IS_FLOAT(centers_coords);
CHECK_IS_FLOAT(points_coords);
int b = centers_coords.size(0);
int m = centers_coords.size(2);
int n = points_coords.size(2);
at::Tensor neighbors_indices = torch::zeros(
{b, m, num_neighbors},
at::device(centers_coords.device()).dtype(at::ScalarType::Int));
ball_query(b, n, m, radius * radius, num_neighbors,
centers_coords.data_ptr<float>(),
points_coords.data_ptr<float>(),
neighbors_indices.data_ptr<int>());
return neighbors_indices;
}

59
modules/functional/src/ball_query/ball_query.cu Normal file
View file

@ -0,0 +1,59 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: ball query
Args:
b : batch size
n : number of points in point clouds
m : number of query centers
r2 : ball query radius ** 2
u : maximum number of neighbors
centers_coords: coordinates of centers, FloatTensor[b, 3, m]
points_coords : coordinates of points, FloatTensor[b, 3, n]
neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
*/
__global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
const float *__restrict__ centers_coords,
const float *__restrict__ points_coords,
int *__restrict__ neighbors_indices) {
int batch_index = blockIdx.x;
int index = threadIdx.x;
int stride = blockDim.x;
points_coords += batch_index * n * 3;
centers_coords += batch_index * m * 3;
neighbors_indices += batch_index * m * u;
for (int j = index; j < m; j += stride) {
float center_x = centers_coords[j];
float center_y = centers_coords[j + m];
float center_z = centers_coords[j + m + m];
for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
float dx = center_x - points_coords[k];
float dy = center_y - points_coords[k + n];
float dz = center_z - points_coords[k + n + n];
float d2 = dx * dx + dy * dy + dz * dz;
if (d2 < r2) {
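        // First neighbor found: pre-fill all u slots with its index, so
        // centers with fewer than u neighbors repeat their first neighbor
        // instead of keeping the zero-initialized index 0.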
if (cnt == 0) {
for (int v = 0; v < u; ++v) {
neighbors_indices[j * u + v] = k;
}
}
neighbors_indices[j * u + cnt] = k;
++cnt;
}
}
}
}
void ball_query(int b, int n, int m, float r2, int u,
const float *centers_coords, const float *points_coords,
int *neighbors_indices) {
ball_query_kernel<<<b, optimal_num_threads(m), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
CUDA_CHECK_ERRORS();
}

8
modules/functional/src/ball_query/ball_query.cuh Normal file
View file

@ -0,0 +1,8 @@
#ifndef _BALL_QUERY_CUH
#define _BALL_QUERY_CUH
void ball_query(int b, int n, int m, float r2, int u,
const float *centers_coords, const float *points_coords,
int *neighbors_indices);
#endif

10
modules/functional/src/ball_query/ball_query.hpp Normal file
View file

@ -0,0 +1,10 @@
#ifndef _BALL_QUERY_HPP
#define _BALL_QUERY_HPP
#include <torch/extension.h>
at::Tensor ball_query_forward(at::Tensor centers_coords,
at::Tensor points_coords, const float radius,
const int num_neighbors);
#endif

37
modules/functional/src/bindings.cpp Normal file
View file

@ -0,0 +1,37 @@
#include <pybind11/pybind11.h>
#include "ball_query/ball_query.hpp"
#include "grouping/grouping.hpp"
#include "interpolate/neighbor_interpolate.hpp"
#include "interpolate/trilinear_devox.hpp"
#include "sampling/sampling.hpp"
#include "voxelization/vox.hpp"
PYBIND11_MODULE(_pvcnn_backend, m) {
m.def("gather_features_forward", &gather_features_forward,
"Gather Centers' Features forward (CUDA)");
m.def("gather_features_backward", &gather_features_backward,
"Gather Centers' Features backward (CUDA)");
m.def("furthest_point_sampling", &furthest_point_sampling_forward,
"Furthest Point Sampling (CUDA)");
m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
m.def("grouping_forward", &grouping_forward,
"Grouping Features forward (CUDA)");
m.def("grouping_backward", &grouping_backward,
"Grouping Features backward (CUDA)");
m.def("three_nearest_neighbors_interpolate_forward",
&three_nearest_neighbors_interpolate_forward,
"3 Nearest Neighbors Interpolate forward (CUDA)");
m.def("three_nearest_neighbors_interpolate_backward",
&three_nearest_neighbors_interpolate_backward,
"3 Nearest Neighbors Interpolate backward (CUDA)");
m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
"Trilinear Devoxelization forward (CUDA)");
m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
"Trilinear Devoxelization backward (CUDA)");
m.def("avg_voxelize_forward", &avg_voxelize_forward,
"Voxelization forward with average pooling (CUDA)");
m.def("avg_voxelize_backward", &avg_voxelize_backward,
"Voxelization backward (CUDA)");
}

39
modules/functional/src/cuda_utils.cuh Normal file
View file

@ -0,0 +1,39 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#define MAXIMUM_THREADS 512
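// Rounds work_size down to the nearest power of two, clamped to
// [1, MAXIMUM_THREADS], giving an efficient per-block thread count.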
inline int optimal_num_threads(int work_size) {
const int pow_2 = std::log2(static_cast<double>(work_size));
return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
}
inline dim3 optimal_block_config(int x, int y) {
const int x_threads = optimal_num_threads(x);
const int y_threads =
max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
dim3 block_config(x_threads, y_threads, 1);
return block_config;
}
#define CUDA_CHECK_ERRORS() \
{ \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
__FILE__); \
exit(-1); \
} \
}
#endif

44
modules/functional/src/grouping/grouping.cpp Normal file
View file

@ -0,0 +1,44 @@
#include "grouping.hpp"
#include "grouping.cuh"
#include "../utils.hpp"
at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
CHECK_CUDA(features);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(features);
CHECK_IS_INT(indices);
int b = features.size(0);
int c = features.size(1);
int n = features.size(2);
int m = indices.size(1);
int u = indices.size(2);
at::Tensor output = torch::zeros(
{b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
output.data_ptr<float>());
return output;
}
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
const int n) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
int b = grad_y.size(0);
int c = grad_y.size(1);
int m = indices.size(1);
int u = indices.size(2);
at::Tensor grad_x = torch::zeros(
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
indices.data_ptr<int>(), grad_x.data_ptr<float>());
return grad_x;
}

85
modules/functional/src/grouping/grouping.cu Normal file
View file

@ -0,0 +1,85 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: grouping features of neighbors (forward)
Args:
b : batch size
    c : #channels of features
n : number of points in point clouds
m : number of query centers
u : maximum number of neighbors
features: points' features, FloatTensor[b, c, n]
indices : neighbor indices in points, IntTensor[b, m, u]
out : gathered features, FloatTensor[b, c, m, u]
*/
__global__ void grouping_kernel(int b, int c, int n, int m, int u,
const float *__restrict__ features,
const int *__restrict__ indices,
float *__restrict__ out) {
int batch_index = blockIdx.x;
features += batch_index * n * c;
indices += batch_index * m * u;
out += batch_index * m * u * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * m; i += stride) {
const int l = i / m;
const int j = i % m;
for (int k = 0; k < u; ++k) {
out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
}
}
}
void grouping(int b, int c, int n, int m, int u, const float *features,
const int *indices, float *out) {
grouping_kernel<<<b, optimal_block_config(m, c), 0,
at::cuda::getCurrentCUDAStream()>>>(b, c, n, m, u, features,
indices, out);
CUDA_CHECK_ERRORS();
}
/*
Function: grouping features of neighbors (backward)
Args:
b : batch size
    c : #channels of features
n : number of points in point clouds
m : number of query centers
u : maximum number of neighbors
grad_y : grad of gathered features, FloatTensor[b, c, m, u]
indices : neighbor indices in points, IntTensor[b, m, u]
grad_x: grad of points' features, FloatTensor[b, c, n]
*/
__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
const float *__restrict__ grad_y,
const int *__restrict__ indices,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
grad_y += batch_index * m * u * c;
indices += batch_index * m * u;
grad_x += batch_index * n * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * m; i += stride) {
const int l = i / m;
const int j = i % m;
for (int k = 0; k < u; ++k) {
atomicAdd(grad_x + l * n + indices[j * u + k],
grad_y[(l * m + j) * u + k]);
}
}
}
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
const int *indices, float *grad_x) {
grouping_grad_kernel<<<b, optimal_block_config(m, c), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, u, grad_y, indices, grad_x);
CUDA_CHECK_ERRORS();
}

9
modules/functional/src/grouping/grouping.cuh Normal file
View file

@ -0,0 +1,9 @@
#ifndef _GROUPING_CUH
#define _GROUPING_CUH
void grouping(int b, int c, int n, int m, int u, const float *features,
const int *indices, float *out);
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
const int *indices, float *grad_x);
#endif

10
modules/functional/src/grouping/grouping.hpp Normal file
View file

@ -0,0 +1,10 @@
#ifndef _GROUPING_HPP
#define _GROUPING_HPP
#include <torch/extension.h>
at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
const int n);
#endif

65
modules/functional/src/interpolate/neighbor_interpolate.cpp Normal file
View file

@ -0,0 +1,65 @@
#include "neighbor_interpolate.hpp"
#include "neighbor_interpolate.cuh"
#include "../utils.hpp"
std::vector<at::Tensor>
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
at::Tensor centers_coords,
at::Tensor centers_features) {
CHECK_CUDA(points_coords);
CHECK_CUDA(centers_coords);
CHECK_CUDA(centers_features);
CHECK_CONTIGUOUS(points_coords);
CHECK_CONTIGUOUS(centers_coords);
CHECK_CONTIGUOUS(centers_features);
CHECK_IS_FLOAT(points_coords);
CHECK_IS_FLOAT(centers_coords);
CHECK_IS_FLOAT(centers_features);
int b = centers_features.size(0);
int c = centers_features.size(1);
int m = centers_features.size(2);
int n = points_coords.size(2);
at::Tensor indices = torch::zeros(
{b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
at::Tensor weights = torch::zeros(
{b, 3, n},
at::device(points_coords.device()).dtype(at::ScalarType::Float));
at::Tensor output = torch::zeros(
{b, c, n},
at::device(centers_features.device()).dtype(at::ScalarType::Float));
three_nearest_neighbors_interpolate(
b, c, m, n, points_coords.data_ptr<float>(),
centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
indices.data_ptr<int>(), weights.data_ptr<float>(),
output.data_ptr<float>());
return {output, indices, weights};
}
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
at::Tensor indices,
at::Tensor weights,
const int m) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CUDA(weights);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_CONTIGUOUS(weights);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
CHECK_IS_FLOAT(weights);
int b = grad_y.size(0);
int c = grad_y.size(1);
int n = grad_y.size(2);
at::Tensor grad_x = torch::zeros(
{b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
three_nearest_neighbors_interpolate_grad(
b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
weights.data_ptr<float>(), grad_x.data_ptr<float>());
return grad_x;
}

181
modules/functional/src/interpolate/neighbor_interpolate.cu Normal file
View file

@ -0,0 +1,181 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: three nearest neighbors
Args:
b : batch size
n : number of points in point clouds
m : number of query centers
points_coords : coordinates of points, FloatTensor[b, 3, n]
centers_coords: coordinates of centers, FloatTensor[b, 3, m]
weights : weights of nearest 3 centers to the point,
FloatTensor[b, 3, n]
indices : indices of nearest 3 centers to the point,
IntTensor[b, 3, n]
*/
__global__ void three_nearest_neighbors_kernel(
int b, int n, int m, const float *__restrict__ points_coords,
const float *__restrict__ centers_coords, float *__restrict__ weights,
int *__restrict__ indices) {
int batch_index = blockIdx.x;
int index = threadIdx.x;
int stride = blockDim.x;
points_coords += batch_index * 3 * n;
weights += batch_index * 3 * n;
indices += batch_index * 3 * n;
centers_coords += batch_index * 3 * m;
for (int j = index; j < n; j += stride) {
float ux = points_coords[j];
float uy = points_coords[j + n];
float uz = points_coords[j + n + n];
double best0 = 1e40, best1 = 1e40, best2 = 1e40;
int besti0 = 0, besti1 = 0, besti2 = 0;
for (int k = 0; k < m; ++k) {
float x = centers_coords[k];
float y = centers_coords[k + m];
float z = centers_coords[k + m + m];
float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
if (d < best2) {
best2 = d;
besti2 = k;
if (d < best1) {
best2 = best1;
besti2 = besti1;
best1 = d;
besti1 = k;
if (d < best0) {
best1 = best0;
besti1 = besti0;
best0 = d;
besti0 = k;
}
}
}
}
best0 = max(min(1e10f, best0), 1e-10f);
best1 = max(min(1e10f, best1), 1e-10f);
best2 = max(min(1e10f, best2), 1e-10f);
float d0d1 = best0 * best1;
float d0d2 = best0 * best2;
float d1d2 = best1 * best2;
float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
weights[j] = d1d2 * d0d1d2;
indices[j] = besti0;
weights[j + n] = d0d2 * d0d1d2;
indices[j + n] = besti1;
weights[j + n + n] = d0d1 * d0d1d2;
indices[j + n + n] = besti2;
}
}
/*
Function: interpolate three nearest neighbors (forward)
Args:
b : batch size
c : #channels of features
m : number of query centers
n : number of points in point clouds
centers_features: features of centers, FloatTensor[b, c, m]
indices : indices of nearest 3 centers to the point,
IntTensor[b, 3, n]
weights : weights for interpolation, FloatTensor[b, 3, n]
out : features of points, FloatTensor[b, c, n]
*/
__global__ void three_nearest_neighbors_interpolate_kernel(
int b, int c, int m, int n, const float *__restrict__ centers_features,
const int *__restrict__ indices, const float *__restrict__ weights,
float *__restrict__ out) {
int batch_index = blockIdx.x;
centers_features += batch_index * m * c;
indices += batch_index * n * 3;
weights += batch_index * n * 3;
out += batch_index * n * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * n; i += stride) {
const int l = i / n;
const int j = i % n;
float w1 = weights[j];
float w2 = weights[j + n];
float w3 = weights[j + n + n];
int i1 = indices[j];
int i2 = indices[j + n];
int i3 = indices[j + n + n];
out[i] = centers_features[l * m + i1] * w1 +
centers_features[l * m + i2] * w2 +
centers_features[l * m + i3] * w3;
}
}
void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
const float *points_coords,
const float *centers_coords,
const float *centers_features,
int *indices, float *weights,
float *out) {
three_nearest_neighbors_kernel<<<b, optimal_num_threads(n), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, n, m, points_coords, centers_coords, weights, indices);
three_nearest_neighbors_interpolate_kernel<<<
b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
b, c, m, n, centers_features, indices, weights, out);
CUDA_CHECK_ERRORS();
}
/*
Function: interpolate three nearest neighbors (backward)
Args:
b : batch size
c : #channels of features
m : number of query centers
n : number of points in point clouds
grad_y : grad of features of points, FloatTensor[b, c, n]
indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
weights : weights for interpolation, FloatTensor[b, 3, n]
grad_x : grad of features of centers, FloatTensor[b, c, m]
*/
__global__ void three_nearest_neighbors_interpolate_grad_kernel(
int b, int c, int n, int m, const float *__restrict__ grad_y,
const int *__restrict__ indices, const float *__restrict__ weights,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
grad_y += batch_index * n * c;
indices += batch_index * n * 3;
weights += batch_index * n * 3;
grad_x += batch_index * m * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * n; i += stride) {
const int l = i / n;
const int j = i % n;
float w1 = weights[j];
float w2 = weights[j + n];
float w3 = weights[j + n + n];
int i1 = indices[j];
int i2 = indices[j + n];
int i3 = indices[j + n + n];
atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
}
}
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
const float *grad_y,
const int *indices,
const float *weights,
float *grad_x) {
three_nearest_neighbors_interpolate_grad_kernel<<<
b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, grad_y, indices, weights, grad_x);
CUDA_CHECK_ERRORS();
}
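The weights computed in `three_nearest_neighbors_kernel` above implement inverse-distance weighting over the three nearest centers, with the squared distances clamped to [1e-10, 1e10]; multiplying numerator and denominator by d0*d1*d2 gives the product form the kernel evaluates:

```latex
w_i = \frac{1/d_i}{1/d_0 + 1/d_1 + 1/d_2}
    = \frac{\prod_{j \ne i} d_j}{d_0 d_1 + d_0 d_2 + d_1 d_2},
\qquad i \in \{0, 1, 2\}
```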

16
modules/functional/src/interpolate/neighbor_interpolate.cuh Normal file
View file

@ -0,0 +1,16 @@
#ifndef _NEIGHBOR_INTERPOLATE_CUH
#define _NEIGHBOR_INTERPOLATE_CUH
void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
const float *points_coords,
const float *centers_coords,
const float *centers_features,
int *indices, float *weights,
float *out);
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
const float *grad_y,
const int *indices,
const float *weights,
float *grad_x);
#endif

16
modules/functional/src/interpolate/neighbor_interpolate.hpp Normal file
View file

@ -0,0 +1,16 @@
#ifndef _NEIGHBOR_INTERPOLATE_HPP
#define _NEIGHBOR_INTERPOLATE_HPP
#include <torch/extension.h>
#include <vector>
std::vector<at::Tensor>
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
at::Tensor centers_coords,
at::Tensor centers_features);
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
at::Tensor indices,
at::Tensor weights,
const int m);
#endif

91
modules/functional/src/interpolate/trilinear_devox.cpp Normal file
View file

@ -0,0 +1,91 @@
#include "trilinear_devox.hpp"
#include "trilinear_devox.cuh"
#include "../utils.hpp"
/*
Function: trilinear devoxelization (forward)
Args:
r : voxel resolution
    training : whether in training mode
coords : the coordinates of points, FloatTensor[b, 3, n]
features : features, FloatTensor[b, c, s], s = r ** 3
Return:
outs : outputs, FloatTensor[b, c, n]
inds : the voxel coordinates of point cube, IntTensor[b, 8, n]
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
*/
std::vector<at::Tensor>
trilinear_devoxelize_forward(const int r, const bool is_training,
const at::Tensor coords,
const at::Tensor features) {
CHECK_CUDA(features);
CHECK_CUDA(coords);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(coords);
CHECK_IS_FLOAT(features);
CHECK_IS_FLOAT(coords);
int b = features.size(0);
int c = features.size(1);
int n = coords.size(2);
int r2 = r * r;
int r3 = r2 * r;
at::Tensor outs = torch::zeros(
{b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float));
if (is_training) {
at::Tensor inds = torch::zeros(
{b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Int));
at::Tensor wgts = torch::zeros(
{b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float));
trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr<float>(),
features.data_ptr<float>(), inds.data_ptr<int>(),
wgts.data_ptr<float>(), outs.data_ptr<float>());
return {outs, inds, wgts};
} else {
at::Tensor inds = torch::zeros(
{1}, at::device(features.device()).dtype(at::ScalarType::Int));
at::Tensor wgts = torch::zeros(
{1}, at::device(features.device()).dtype(at::ScalarType::Float));
trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr<float>(),
features.data_ptr<float>(), inds.data_ptr<int>(),
wgts.data_ptr<float>(), outs.data_ptr<float>());
return {outs, inds, wgts};
}
}
/*
Function: trilinear devoxelization (backward)
Args:
grad_y : grad outputs, FloatTensor[b, c, n]
indices : the voxel coordinates of point cube, IntTensor[b, 8, n]
weights : weight for trilinear interpolation, FloatTensor[b, 8, n]
r : voxel resolution
Return:
grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3
*/
at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor weights,
const int r) {
CHECK_CUDA(grad_y);
CHECK_CUDA(weights);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(weights);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_FLOAT(weights);
CHECK_IS_INT(indices);
int b = grad_y.size(0);
int c = grad_y.size(1);
int n = grad_y.size(2);
int r3 = r * r * r;
at::Tensor grad_x = torch::zeros(
{b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr<int>(),
weights.data_ptr<float>(), grad_y.data_ptr<float>(),
grad_x.data_ptr<float>());
return grad_x;
}

178
modules/functional/src/interpolate/trilinear_devox.cu Normal file
View file

@ -0,0 +1,178 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
  Function: trilinear devoxelization (forward)
Args:
b : batch size
c : #channels
n : number of points
r : voxel resolution
r2 : r ** 2
r3 : r ** 3
coords : the coordinates of points, FloatTensor[b, 3, n]
feat : features, FloatTensor[b, c, r3]
inds : the voxel indices of point cube, IntTensor[b, 8, n]
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
outs : outputs, FloatTensor[b, c, n]
*/
__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2,
int r3, bool is_training,
const float *__restrict__ coords,
const float *__restrict__ feat,
int *__restrict__ inds,
float *__restrict__ wgts,
float *__restrict__ outs) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
coords += batch_index * n * 3;
inds += batch_index * n * 8;
wgts += batch_index * n * 8;
feat += batch_index * c * r3;
outs += batch_index * c * n;
for (int i = index; i < n; i += stride) {
float x = coords[i];
float y = coords[i + n];
float z = coords[i + n + n];
float x_lo_f = floorf(x);
float y_lo_f = floorf(y);
float z_lo_f = floorf(z);
float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f)
float y_d_1 = y - y_lo_f;
float z_d_1 = z - z_lo_f;
float x_d_0 = 1.0f - x_d_1;
float y_d_0 = 1.0f - y_d_1;
float z_d_0 = 1.0f - z_d_1;
float wgt000 = x_d_0 * y_d_0 * z_d_0;
float wgt001 = x_d_0 * y_d_0 * z_d_1;
float wgt010 = x_d_0 * y_d_1 * z_d_0;
float wgt011 = x_d_0 * y_d_1 * z_d_1;
float wgt100 = x_d_1 * y_d_0 * z_d_0;
float wgt101 = x_d_1 * y_d_0 * z_d_1;
float wgt110 = x_d_1 * y_d_1 * z_d_0;
float wgt111 = x_d_1 * y_d_1 * z_d_1;
int x_lo = static_cast<int>(x_lo_f);
int y_lo = static_cast<int>(y_lo_f);
int z_lo = static_cast<int>(z_lo_f);
int x_hi = (x_d_1 > 0) ? -1 : 0;
int y_hi = (y_d_1 > 0) ? -1 : 0;
int z_hi = (z_d_1 > 0) ? 1 : 0;
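    // Branchless corner offsets: x_hi/y_hi are 0 or -1 (all ones in two's
    // complement), so (x_hi & r2) and (y_hi & r) add one voxel step along x
    // or y only when the fractional part is non-zero; z_hi is the +1 step
    // along z directly. A zero fractional part collapses the "high" corner
    // onto the low one, which also avoids indexing past the grid boundary.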
int idx000 = x_lo * r2 + y_lo * r + z_lo;
int idx001 = idx000 + z_hi; // x_lo * r2 + y_lo * r + z_hi;
int idx010 = idx000 + (y_hi & r); // x_lo * r2 + y_hi * r + z_lo;
int idx011 = idx010 + z_hi; // x_lo * r2 + y_hi * r + z_hi;
int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo;
int idx101 = idx100 + z_hi; // x_hi * r2 + y_lo * r + z_hi;
int idx110 = idx100 + (y_hi & r); // x_hi * r2 + y_hi * r + z_lo;
int idx111 = idx110 + z_hi; // x_hi * r2 + y_hi * r + z_hi;
if (is_training) {
wgts[i] = wgt000;
wgts[i + n] = wgt001;
wgts[i + n * 2] = wgt010;
wgts[i + n * 3] = wgt011;
wgts[i + n * 4] = wgt100;
wgts[i + n * 5] = wgt101;
wgts[i + n * 6] = wgt110;
wgts[i + n * 7] = wgt111;
inds[i] = idx000;
inds[i + n] = idx001;
inds[i + n * 2] = idx010;
inds[i + n * 3] = idx011;
inds[i + n * 4] = idx100;
inds[i + n * 5] = idx101;
inds[i + n * 6] = idx110;
inds[i + n * 7] = idx111;
}
for (int j = 0; j < c; j++) {
int jr3 = j * r3;
outs[j * n + i] =
wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] +
wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] +
wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] +
wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111];
}
}
}
/*
  Function: trilinear devoxelization (backward)
Args:
b : batch size
c : #channels
n : number of points
r3 : voxel cube size = voxel resolution ** 3
inds : the voxel indices of point cube, IntTensor[b, 8, n]
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
grad_y : grad outputs, FloatTensor[b, c, n]
grad_x : grad inputs, FloatTensor[b, c, r3]
*/
__global__ void trilinear_devoxelize_grad_kernel(
int b, int c, int n, int r3, const int *__restrict__ inds,
const float *__restrict__ wgts, const float *__restrict__ grad_y,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
inds += batch_index * n * 8;
wgts += batch_index * n * 8;
grad_x += batch_index * c * r3;
grad_y += batch_index * c * n;
for (int i = index; i < n; i += stride) {
int idx000 = inds[i];
int idx001 = inds[i + n];
int idx010 = inds[i + n * 2];
int idx011 = inds[i + n * 3];
int idx100 = inds[i + n * 4];
int idx101 = inds[i + n * 5];
int idx110 = inds[i + n * 6];
int idx111 = inds[i + n * 7];
float wgt000 = wgts[i];
float wgt001 = wgts[i + n];
float wgt010 = wgts[i + n * 2];
float wgt011 = wgts[i + n * 3];
float wgt100 = wgts[i + n * 4];
float wgt101 = wgts[i + n * 5];
float wgt110 = wgts[i + n * 6];
float wgt111 = wgts[i + n * 7];
for (int j = 0; j < c; j++) {
int jr3 = j * r3;
float g = grad_y[j * n + i];
atomicAdd(grad_x + jr3 + idx000, wgt000 * g);
atomicAdd(grad_x + jr3 + idx001, wgt001 * g);
atomicAdd(grad_x + jr3 + idx010, wgt010 * g);
atomicAdd(grad_x + jr3 + idx011, wgt011 * g);
atomicAdd(grad_x + jr3 + idx100, wgt100 * g);
atomicAdd(grad_x + jr3 + idx101, wgt101 * g);
atomicAdd(grad_x + jr3 + idx110, wgt110 * g);
atomicAdd(grad_x + jr3 + idx111, wgt111 * g);
}
}
}
void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
bool training, const float *coords, const float *feat,
int *inds, float *wgts, float *outs) {
trilinear_devoxelize_kernel<<<b, optimal_num_threads(n)>>>(
b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs);
CUDA_CHECK_ERRORS();
}
void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
const float *wgts, const float *grad_y,
float *grad_x) {
trilinear_devoxelize_grad_kernel<<<b, optimal_num_threads(n)>>>(
b, c, n, r3, inds, wgts, grad_y, grad_x);
CUDA_CHECK_ERRORS();
}

13
modules/functional/src/interpolate/trilinear_devox.cuh Normal file
View file

@ -0,0 +1,13 @@
#ifndef _TRILINEAR_DEVOX_CUH
#define _TRILINEAR_DEVOX_CUH
// CUDA function declarations
void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
bool is_training, const float *coords,
const float *feat, int *inds, float *wgts,
float *outs);
void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
const float *wgts, const float *grad_y,
float *grad_x);
#endif

16
modules/functional/src/interpolate/trilinear_devox.hpp Normal file
View file

@ -0,0 +1,16 @@
#ifndef _TRILINEAR_DEVOX_HPP
#define _TRILINEAR_DEVOX_HPP
#include <torch/torch.h>
#include <vector>
std::vector<at::Tensor> trilinear_devoxelize_forward(const int r,
const bool is_training,
const at::Tensor coords,
const at::Tensor features);
at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor weights, const int r);
#endif

58
modules/functional/src/sampling/sampling.cpp Normal file
View file

@ -0,0 +1,58 @@
#include "sampling.hpp"
#include "sampling.cuh"
#include "../utils.hpp"
at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) {
CHECK_CUDA(features);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(features);
CHECK_IS_INT(indices);
int b = features.size(0);
int c = features.size(1);
int n = features.size(2);
int m = indices.size(1);
at::Tensor output = torch::zeros(
{b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float));
gather_features(b, c, n, m, features.data_ptr<float>(),
indices.data_ptr<int>(), output.data_ptr<float>());
return output;
}
at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
const int n) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
int b = grad_y.size(0);
int c = grad_y.size(1);
at::Tensor grad_x = torch::zeros(
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr<float>(),
indices.data_ptr<int>(), grad_x.data_ptr<float>());
return grad_x;
}
at::Tensor furthest_point_sampling_forward(at::Tensor coords,
const int num_samples) {
CHECK_CUDA(coords);
CHECK_CONTIGUOUS(coords);
CHECK_IS_FLOAT(coords);
int b = coords.size(0);
int n = coords.size(2);
at::Tensor indices = torch::zeros(
{b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int));
at::Tensor distances = torch::full(
{b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float));
furthest_point_sampling(b, n, num_samples, coords.data_ptr<float>(),
distances.data_ptr<float>(), indices.data_ptr<int>());
return indices;
}

174
modules/functional/src/sampling/sampling.cu Normal file
View file

@ -0,0 +1,174 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: gather centers' features (forward)
Args:
b : batch size
    c : #channels of features
n : number of points in point clouds
m : number of query/sampled centers
features: points' features, FloatTensor[b, c, n]
indices : centers' indices in points, IntTensor[b, m]
out : gathered features, FloatTensor[b, c, m]
*/
__global__ void gather_features_kernel(int b, int c, int n, int m,
const float *__restrict__ features,
const int *__restrict__ indices,
float *__restrict__ out) {
int batch_index = blockIdx.x;
int channel_index = blockIdx.y;
int temp_index = batch_index * c + channel_index;
features += temp_index * n;
indices += batch_index * m;
out += temp_index * m;
for (int j = threadIdx.x; j < m; j += blockDim.x) {
out[j] = features[indices[j]];
}
}
void gather_features(int b, int c, int n, int m, const float *features,
const int *indices, float *out) {
gather_features_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, features, indices, out);
CUDA_CHECK_ERRORS();
}
/*
Function: gather centers' features (backward)
Args:
b : batch size
    c : #channels of features
n : number of points in point clouds
m : number of query/sampled centers
grad_y : grad of gathered features, FloatTensor[b, c, m]
indices : centers' indices in points, IntTensor[b, m]
grad_x : grad of points' features, FloatTensor[b, c, n]
*/
__global__ void gather_features_grad_kernel(int b, int c, int n, int m,
const float *__restrict__ grad_y,
const int *__restrict__ indices,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
int channel_index = blockIdx.y;
int temp_index = batch_index * c + channel_index;
grad_y += temp_index * m;
indices += batch_index * m;
grad_x += temp_index * n;
for (int j = threadIdx.x; j < m; j += blockDim.x) {
atomicAdd(grad_x + indices[j], grad_y[j]);
}
}
void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
const int *indices, float *grad_x) {
gather_features_grad_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, grad_y, indices, grad_x);
CUDA_CHECK_ERRORS();
}
/*
Function: furthest point sampling
Args:
b : batch size
n : number of points in point clouds
m : number of query/sampled centers
coords : points' coords, FloatTensor[b, 3, n]
    distances : minimum distance of a point to the sampled set, FloatTensor[b, n]
indices : sampled centers' indices in points, IntTensor[b, m]
*/
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ coords,
float *__restrict__ distances,
int *__restrict__ indices) {
if (m <= 0)
return;
int batch_index = blockIdx.x;
coords += batch_index * n * 3;
distances += batch_index * n;
indices += batch_index * m;
const int BlockSize = 512;
__shared__ float dists[BlockSize];
__shared__ int dists_i[BlockSize];
const int BufferSize = 3072;
__shared__ float buf[BufferSize * 3];
int old = 0;
if (threadIdx.x == 0)
indices[0] = old;
for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) {
buf[j] = coords[j];
buf[j + BufferSize] = coords[j + n];
buf[j + BufferSize + BufferSize] = coords[j + n + n];
}
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0; // best index
float best = -1; // farthest distance
// calculating the distance with the latest sampled point
float x1 = coords[old];
float y1 = coords[old + n];
float z1 = coords[old + n + n];
for (int k = threadIdx.x; k < n; k += blockDim.x) {
// fetch distance at block n, thread k
float td = distances[k];
float x2, y2, z2;
if (k < BufferSize) {
x2 = buf[k];
y2 = buf[k + BufferSize];
z2 = buf[k + BufferSize + BufferSize];
} else {
x2 = coords[k];
y2 = coords[k + n];
z2 = coords[k + n + n];
}
float d =
(x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, td);
// update "point-to-set" distance
if (d2 != td)
distances[k] = d2;
// update the farthest distance at sample step j
if (d2 > best) {
best = d2;
besti = k;
}
}
dists[threadIdx.x] = best;
dists_i[threadIdx.x] = besti;
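    // Shared-memory argmax tree reduction: at each level u, paired entries
    // keep the larger candidate distance, so after log2(blockDim.x) rounds
    // dists_i[0] holds the index of the farthest point for this step.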
for (int u = 0; (1 << u) < blockDim.x; u++) {
__syncthreads();
if (threadIdx.x < (blockDim.x >> (u + 1))) {
int i1 = (threadIdx.x * 2) << u;
int i2 = (threadIdx.x * 2 + 1) << u;
if (dists[i1] < dists[i2]) {
dists[i1] = dists[i2];
dists_i[i1] = dists_i[i2];
}
}
}
__syncthreads();
// finish sample step j; old is the sampled index
old = dists_i[0];
if (threadIdx.x == 0)
indices[j] = old;
}
}
void furthest_point_sampling(int b, int n, int m, const float *coords,
float *distances, int *indices) {
furthest_point_sampling_kernel<<<b, 512>>>(b, n, m, coords, distances,
indices);
CUDA_CHECK_ERRORS();
}

11
modules/functional/src/sampling/sampling.cuh Normal file
View file

@ -0,0 +1,11 @@
#ifndef _SAMPLING_CUH
#define _SAMPLING_CUH
void gather_features(int b, int c, int n, int m, const float *features,
const int *indices, float *out);
void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
const int *indices, float *grad_x);
void furthest_point_sampling(int b, int n, int m, const float *coords,
float *distances, int *indices);
#endif

12
modules/functional/src/sampling/sampling.hpp Normal file
View file

@ -0,0 +1,12 @@
#ifndef _SAMPLING_HPP
#define _SAMPLING_HPP
#include <torch/extension.h>
at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices);
at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
const int n);
at::Tensor furthest_point_sampling_forward(at::Tensor coords,
const int num_samples);
#endif

20
modules/functional/src/utils.hpp Normal file
View file

@ -0,0 +1,20 @@
#ifndef _UTILS_HPP
#define _UTILS_HPP
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
#x " must be an int tensor")
#define CHECK_IS_FLOAT(x) \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
#x " must be a float tensor")
#endif

76
modules/functional/src/voxelization/vox.cpp Normal file
View file

@ -0,0 +1,76 @@
#include "vox.hpp"
#include "vox.cuh"
#include "../utils.hpp"
/*
Function: average pool voxelization (forward)
Args:
features: features, FloatTensor[b, c, n]
coords : coords of each point, IntTensor[b, 3, n]
resolution : voxel resolution
Return:
out : outputs, FloatTensor[b, c, s], s = r ** 3
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
*/
std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
const at::Tensor coords,
const int resolution) {
CHECK_CUDA(features);
CHECK_CUDA(coords);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(coords);
CHECK_IS_FLOAT(features);
CHECK_IS_INT(coords);
int b = features.size(0);
int c = features.size(1);
int n = features.size(2);
int r = resolution;
int r2 = r * r;
int r3 = r2 * r;
at::Tensor ind = torch::zeros(
{b, n}, at::device(features.device()).dtype(at::ScalarType::Int));
at::Tensor out = torch::zeros(
{b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float));
at::Tensor cnt = torch::zeros(
{b, r3}, at::device(features.device()).dtype(at::ScalarType::Int));
avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr<int>(),
features.data_ptr<float>(), ind.data_ptr<int>(),
cnt.data_ptr<int>(), out.data_ptr<float>());
return {out, ind, cnt};
}
/*
Function: average pool voxelization (backward)
Args:
grad_y : grad outputs, FloatTensor[b, c, s]
indices: voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
Return:
grad_x : grad inputs, FloatTensor[b, c, n]
*/
at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor cnt) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CUDA(cnt);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_CONTIGUOUS(cnt);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
CHECK_IS_INT(cnt);
int b = grad_y.size(0);
int c = grad_y.size(1);
int s = grad_y.size(2);
int n = indices.size(1);
at::Tensor grad_x = torch::zeros(
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
avg_voxelize_grad(b, c, n, s, indices.data_ptr<int>(), cnt.data_ptr<int>(),
grad_y.data_ptr<float>(), grad_x.data_ptr<float>());
return grad_x;
}

126
modules/functional/src/voxelization/vox.cu Normal file
View file

@ -0,0 +1,126 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: get how many points in each voxel grid
Args:
b : batch size
n : number of points
r : voxel resolution
r2 : = r * r
r3 : s, voxel cube size = r ** 3
coords : coords of each point, IntTensor[b, 3, n]
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
*/
__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3,
const int *__restrict__ coords,
int *__restrict__ ind, int *cnt) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
coords += batch_index * n * 3;
ind += batch_index * n;
cnt += batch_index * r3;
for (int i = index; i < n; i += stride) {
// if (ind[i] == -1)
// continue;
ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + n + n];
atomicAdd(cnt + ind[i], 1);
}
}
/*
Function: average pool voxelization (forward)
Args:
b : batch size
c : #channels
n : number of points
s : voxel cube size = voxel resolution ** 3
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
feat: features, FloatTensor[b, c, n]
out : outputs, FloatTensor[b, c, s]
*/
__global__ void avg_voxelize_kernel(int b, int c, int n, int s,
const int *__restrict__ ind,
const int *__restrict__ cnt,
const float *__restrict__ feat,
float *__restrict__ out) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
ind += batch_index * n;
feat += batch_index * c * n;
out += batch_index * c * s;
cnt += batch_index * s;
for (int i = index; i < n; i += stride) {
int pos = ind[i];
// if (pos == -1)
// continue;
int cur_cnt = cnt[pos];
if (cur_cnt > 0) {
float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
for (int j = 0; j < c; j++) {
atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt);
}
}
}
}
/*
Function: average pool voxelization (backward)
Args:
b : batch size
c : #channels
n : number of points
r3 : voxel cube size = voxel resolution ** 3
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
grad_y : grad outputs, FloatTensor[b, c, s]
grad_x : grad inputs, FloatTensor[b, c, n]
*/
__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3,
const int *__restrict__ ind,
const int *__restrict__ cnt,
const float *__restrict__ grad_y,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
ind += batch_index * n;
grad_x += batch_index * c * n;
grad_y += batch_index * c * r3;
cnt += batch_index * r3;
for (int i = index; i < n; i += stride) {
int pos = ind[i];
// if (pos == -1)
// continue;
int cur_cnt = cnt[pos];
if (cur_cnt > 0) {
float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
for (int j = 0; j < c; j++) {
atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt);
}
}
}
}
void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
const float *feat, int *ind, int *cnt, float *out) {
grid_stats_kernel<<<b, optimal_num_threads(n)>>>(b, n, r, r2, r3, coords, ind,
cnt);
avg_voxelize_kernel<<<b, optimal_num_threads(n)>>>(b, c, n, r3, ind, cnt,
feat, out);
CUDA_CHECK_ERRORS();
}
void avg_voxelize_grad(int b, int c, int n, int s, const int *ind,
const int *cnt, const float *grad_y, float *grad_x) {
avg_voxelize_grad_kernel<<<b, optimal_num_threads(n)>>>(b, c, n, s, ind, cnt,
grad_y, grad_x);
CUDA_CHECK_ERRORS();
}

10
modules/functional/src/voxelization/vox.cuh Normal file
View file

@ -0,0 +1,10 @@
#ifndef _VOX_CUH
#define _VOX_CUH
// CUDA function declarations
void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
const float *feat, int *ind, int *cnt, float *out);
void avg_voxelize_grad(int b, int c, int n, int s, const int *idx,
const int *cnt, const float *grad_y, float *grad_x);
#endif

15
modules/functional/src/voxelization/vox.hpp Normal file
View file

@ -0,0 +1,15 @@
#ifndef _VOX_HPP
#define _VOX_HPP
#include <torch/torch.h>
#include <vector>
std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
const at::Tensor coords,
const int resolution);
at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor cnt);
#endif

40
modules/functional/voxelization.py Normal file
View file

@ -0,0 +1,40 @@
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['avg_voxelize']
class AvgVoxelization(Function):
@staticmethod
def forward(ctx, features, coords, resolution):
"""
:param ctx:
:param features: Features of the point cloud, FloatTensor[B, C, N]
:param coords: Voxelized Coordinates of each point, IntTensor[B, 3, N]
:param resolution: Voxel resolution
:return:
Voxelized Features, FloatTensor[B, C, R, R, R]
"""
features = features.contiguous()
coords = coords.int().contiguous()
b, c, _ = features.shape
out, indices, counts = _backend.avg_voxelize_forward(features, coords, resolution)
ctx.save_for_backward(indices, counts)
return out.view(b, c, resolution, resolution, resolution)
@staticmethod
def backward(ctx, grad_output):
"""
:param ctx:
:param grad_output: gradient of output, FloatTensor[B, C, R, R, R]
:return:
gradient of inputs, FloatTensor[B, C, N]
"""
b, c = grad_output.shape[:2]
indices, counts = ctx.saved_tensors
grad_features = _backend.avg_voxelize_backward(grad_output.contiguous().view(b, c, -1), indices, counts)
return grad_features, None, None
avg_voxelize = AvgVoxelization.apply
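An illustrative sketch: integer voxel coordinates in [0, R-1] (e.g. those produced by `modules.voxelization.Voxelization`) are average-pooled into a dense grid:

```python
import torch
from modules.functional import avg_voxelize

feats = torch.rand(2, 8, 1024, device='cuda')                     # [B, C, N]
coords = torch.randint(0, 16, (2, 3, 1024), device='cuda').int()  # [B, 3, N]
grid = avg_voxelize(feats, coords, 16)   # FloatTensor[2, 8, 16, 16, 16]
```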

10
modules/loss.py Normal file
View file

@ -0,0 +1,10 @@
import torch.nn as nn
import modules.functional as F
__all__ = ['KLLoss']
class KLLoss(nn.Module):
def forward(self, x, y):
return F.kl_loss(x, y)

113
modules/pointnet.py Normal file
View file

@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import modules.functional as F
from modules.ball_query import BallQuery
from modules.shared_mlp import SharedMLP
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
class PointNetAModule(nn.Module):
def __init__(self, in_channels, out_channels, include_coordinates=True):
super().__init__()
if not isinstance(out_channels, (list, tuple)):
out_channels = [[out_channels]]
elif not isinstance(out_channels[0], (list, tuple)):
out_channels = [out_channels]
mlps = []
total_out_channels = 0
for _out_channels in out_channels:
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=1)
)
total_out_channels += _out_channels[-1]
self.include_coordinates = include_coordinates
self.out_channels = total_out_channels
self.mlps = nn.ModuleList(mlps)
def forward(self, inputs):
features, coords = inputs
if self.include_coordinates:
features = torch.cat([features, coords], dim=1)
coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
if len(self.mlps) > 1:
features_list = []
for mlp in self.mlps:
features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
return torch.cat(features_list, dim=1), coords
else:
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
def extra_repr(self):
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
class PointNetSAModule(nn.Module):
def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
super().__init__()
if not isinstance(radius, (list, tuple)):
radius = [radius]
if not isinstance(num_neighbors, (list, tuple)):
num_neighbors = [num_neighbors] * len(radius)
assert len(radius) == len(num_neighbors)
if not isinstance(out_channels, (list, tuple)):
out_channels = [[out_channels]] * len(radius)
elif not isinstance(out_channels[0], (list, tuple)):
out_channels = [out_channels] * len(radius)
assert len(radius) == len(out_channels)
groupers, mlps = [], []
total_out_channels = 0
for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
groupers.append(
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
)
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=2)
)
total_out_channels += _out_channels[-1]
self.num_centers = num_centers
self.out_channels = total_out_channels
self.groupers = nn.ModuleList(groupers)
self.mlps = nn.ModuleList(mlps)
def forward(self, inputs):
features, coords, temb = inputs
centers_coords = F.furthest_point_sample(coords, self.num_centers)
features_list = []
for grouper, mlp in zip(self.groupers, self.mlps):
features, temb = mlp(grouper(coords, centers_coords, temb, features))
features_list.append(features.max(dim=-1).values)
        if len(features_list) > 1:
            return torch.cat(features_list, dim=1), centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
else:
return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
def extra_repr(self):
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
class PointNetFPModule(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
def forward(self, inputs):
        if len(inputs) == 4:
points_coords, centers_coords, centers_features, temb = inputs
points_features = None
else:
points_coords, centers_coords, centers_features, points_features, temb = inputs
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
if points_features is not None:
interpolated_features = torch.cat(
[interpolated_features, points_features], dim=1
)
return self.mlp(interpolated_features), points_coords, interpolated_temb

132
modules/pvconv.py Normal file
View file

@ -0,0 +1,132 @@
import torch.nn as nn
import torch
import modules.functional as F
from modules.voxelization import Voxelization
from modules.shared_mlp import SharedMLP
from modules.se import SE3d
__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
class Swish(nn.Module):
def forward(self,x):
return x * torch.sigmoid(x)
class Attention(nn.Module):
def __init__(self, in_ch, num_groups, D=3):
super(Attention, self).__init__()
assert in_ch % num_groups == 0
if D == 3:
self.q = nn.Conv3d(in_ch, in_ch, 1)
self.k = nn.Conv3d(in_ch, in_ch, 1)
self.v = nn.Conv3d(in_ch, in_ch, 1)
self.out = nn.Conv3d(in_ch, in_ch, 1)
elif D == 1:
self.q = nn.Conv1d(in_ch, in_ch, 1)
self.k = nn.Conv1d(in_ch, in_ch, 1)
self.v = nn.Conv1d(in_ch, in_ch, 1)
self.out = nn.Conv1d(in_ch, in_ch, 1)
self.norm = nn.GroupNorm(num_groups, in_ch)
self.nonlin = Swish()
self.sm = nn.Softmax(-1)
def forward(self, x):
B, C = x.shape[:2]
h = x
q = self.q(h).reshape(B,C,-1)
k = self.k(h).reshape(B,C,-1)
v = self.v(h).reshape(B,C,-1)
qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
w = self.sm(qk)
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:])
h = self.out(h)
x = h + x
x = self.nonlin(self.norm(x))
return x
class PVConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.resolution = resolution
self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps)
voxel_layers = [
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.GroupNorm(num_groups=8, num_channels=out_channels),
Swish()
]
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
voxel_layers += [
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.GroupNorm(num_groups=8, num_channels=out_channels),
Attention(out_channels, 8) if attention else Swish()
]
if with_se:
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
self.voxel_layers = nn.Sequential(*voxel_layers)
self.point_features = SharedMLP(in_channels, out_channels)
def forward(self, inputs):
features, coords, temb = inputs
voxel_features, voxel_coords = self.voxelization(features, coords)
voxel_features = self.voxel_layers(voxel_features)
voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
fused_features = voxel_features + self.point_features(features)
return fused_features, coords, temb
class PVConvReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.resolution = resolution
self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps)
voxel_layers = [
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.BatchNorm3d(out_channels),
nn.LeakyReLU(leak, True)
]
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
voxel_layers += [
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.BatchNorm3d(out_channels),
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
]
if with_se:
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
self.voxel_layers = nn.Sequential(*voxel_layers)
self.point_features = SharedMLP(in_channels, out_channels)
def forward(self, inputs):
features, coords, temb = inputs
voxel_features, voxel_coords = self.voxelization(features, coords)
voxel_features = self.voxel_layers(voxel_features)
voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
fused_features = voxel_features + self.point_features(features)
return fused_features, coords, temb
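A sketch of one PVConv block on toy inputs (CUDA required; shapes illustrative): the voxel branch captures coarse neighborhood context while the point-wise SharedMLP branch keeps per-point detail, and the two are fused by addition:

```python
import torch
from modules.pvconv import PVConv

block = PVConv(in_channels=64, out_channels=128, kernel_size=3, resolution=16).cuda()
feats = torch.rand(2, 64, 1024, device='cuda')
coords = torch.rand(2, 3, 1024, device='cuda')
temb = torch.rand(2, 64, 1024, device='cuda')   # time embedding, passed through unchanged
out_feats, out_coords, out_temb = block((feats, coords, temb))
print(out_feats.shape)  # torch.Size([2, 128, 1024])
```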

19
modules/se.py Normal file
View file

@ -0,0 +1,19 @@
import torch.nn as nn
import torch
__all__ = ['SE3d']
class Swish(nn.Module):
def forward(self,x):
return x * torch.sigmoid(x)
class SE3d(nn.Module):
def __init__(self, channel, reduction=8, use_relu=False):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(True) if use_relu else Swish(),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, inputs):
return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)

38
modules/shared_mlp.py Normal file
View file

@ -0,0 +1,38 @@
import torch.nn as nn
import torch
__all__ = ['SharedMLP']
class Swish(nn.Module):
def forward(self,x):
return x * torch.sigmoid(x)
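# Note: despite the `bn` alias below, normalization is GroupNorm with 8 groups in
# both the 1D and 2D cases, applied after each pointwise (kernel-size-1) convolution.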
class SharedMLP(nn.Module):
def __init__(self, in_channels, out_channels, dim=1):
super().__init__()
if dim == 1:
conv = nn.Conv1d
bn = nn.GroupNorm
elif dim == 2:
conv = nn.Conv2d
bn = nn.GroupNorm
else:
raise ValueError('dim must be 1 or 2, got {}'.format(dim))
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
layers = []
for oc in out_channels:
layers.extend([
conv(in_channels, oc, 1),
bn(8, oc),
Swish(),
])
in_channels = oc
self.layers = nn.Sequential(*layers)
def forward(self, inputs):
if isinstance(inputs, (list, tuple)):
return (self.layers(inputs[0]), *inputs[1:])
else:
return self.layers(inputs)

28
modules/voxelization.py Normal file
View file

@ -0,0 +1,28 @@
import torch
import torch.nn as nn
import modules.functional as F
__all__ = ['Voxelization']
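# Voxelization centers each cloud, rescales coordinates into [0, 1] (by the max
# point norm when normalize=True, otherwise assuming inputs already lie in [-1, 1]),
# maps them onto an r^3 grid, and average-pools the features of points that fall
# into the same voxel.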
class Voxelization(nn.Module):
def __init__(self, resolution, normalize=True, eps=0):
super().__init__()
self.r = int(resolution)
self.normalize = normalize
self.eps = eps
def forward(self, features, coords):
coords = coords.detach()
norm_coords = coords - coords.mean(2, keepdim=True)
if self.normalize:
norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
else:
norm_coords = (norm_coords + 1) / 2.0
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
vox_coords = torch.round(norm_coords).to(torch.int32)
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
def extra_repr(self):
return 'resolution={}{}'.format(self.r, ', normalized, eps={}'.format(self.eps) if self.normalize else '')
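# A minimal smoke test of the voxelize step (a sketch, assuming the CUDA extensions
# behind modules.functional are built; shapes are illustrative only):
#
#   feats = torch.randn(2, 16, 1024).cuda()           # (B, C, N) point features
#   coords = (torch.rand(2, 3, 1024).cuda() * 2) - 1  # (B, 3, N) xyz in [-1, 1]
#   voxel_feats, norm_coords = Voxelization(resolution=32)(feats, coords)
#   assert voxel_feats.shape == (2, 16, 32, 32, 32)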

26
requirement_voxel.txt Normal file
View file

@ -0,0 +1,26 @@
conda:
python==3.6
torch==1.4.0
torchvision==0.5.0
cudatoolkit==10.1
kaolin==0.1.0
pytorch3d==0.2.5
lutorpy==1.3.7
xmltodict==0.12.0
numba==0.51.2
pycuda==2019.1.2
matplotlib
pip:
torch-scatter==2.0.4
torch-sparse==0.6.1
torch-cluster==1.5.4
torch-spline-conv==1.2.0
descartes==1.1.0
fire==0.3.1
jupyter==1.0.0
opencv_python==4.3.0
Shapely==1.7.0
Pillow==6.2.1
torch_geometric==1.6.0
open3d

View file

@ -0,0 +1,825 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.partnet import GANdatasetPartNet
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
'''
some utils
'''
def rotation_matrix(axis, theta):
"""
Return the rotation matrix associated with counterclockwise rotation about
the given axis by theta radians.
"""
axis = np.asarray(axis)
axis = axis / np.sqrt(np.dot(axis, axis))
a = np.cos(theta / 2.0)
b, c, d = -axis * np.sin(theta / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
def rotate(vertices, faces):
'''
vertices: [numpoints, 3]
'''
M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
K = rotation_matrix([0, 0, 1], np.pi).transpose()
v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
return v, f
def norm(v, f):
v = (v - v.min())/(v.max() - v.min()) - 0.5
return v, f
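# Note: getGradNorm assumes every parameter already has a populated .grad, so it
# must be called after loss.backward() and before the gradients are cleared.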
def getGradNorm(net):
pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
return pNorm, gradNorm
def weights_init(m):
"""
xavier initialization
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1 and m.weight is not None:
torch.nn.init.xavier_normal_(m.weight)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_()
m.bias.data.fill_(0)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data lies in [0, 1] (see the 0.001 / 0.999 thresholds below)
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
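# Completion convention used throughout: the first sv_points entries along the last
# (point) dimension are the conditioning partial input and are never diffused; the
# reverse process only denoises the remaining points and re-attaches the partial
# input after every step.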
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
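# together, coef1/coef2 give the closed-form posterior mean mu_t = coef1 * x_0 + coef2 * x_t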
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): an intermediate sample is appended to the returned list every `freq` timesteps
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)  # keep the fixed per-shape noise only where t == 0; resample elsewhere
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
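# Beta schedules: 'linear' ramps from b_start to b_end over all timesteps, while
# 'warmX' holds beta at b_end except for a linear warm-up from b_start over the
# first X fraction of the timesteps.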
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category):
train_ds = GANdatasetPartNet('train', data_root, category, npoints)
return train_ds
def get_dataloader(opt, train_dataset, test_dataset=None):
if opt.distribution_type == 'multi':
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
if test_dataset is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
else:
test_sampler = None
else:
train_sampler = None
test_sampler = None
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
if test_dataset is not None:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
else:
test_dataloader = None
return train_dataloader, test_dataloader, train_sampler, test_sampler
def train(gpu, opt, output_dir, noises_init):
set_seed(opt)
logger = setup_logging(output_dir)
if opt.distribution_type == 'multi':
should_diag = gpu==0
else:
should_diag = True
if should_diag:
outf_syn, = setup_output_subdirs(output_dir, 'syn')
if opt.distribution_type == 'multi':
if opt.dist_url == "env://" and opt.rank == -1:
opt.rank = int(os.environ["RANK"])
base_rank = opt.rank * opt.ngpus_per_node
opt.rank = base_rank + gpu
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
world_size=opt.world_size, rank=opt.rank)
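# split the global batch size across the spawned processes and shrink the
# epoch-based save/diag/viz intervals proportionally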
opt.bs = int(opt.bs / opt.ngpus_per_node)
opt.workers = 0
opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
''' data '''
train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
'''
create networks
'''
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
def _transform_(m):
return nn.parallel.DistributedDataParallel(
m, device_ids=[gpu], output_device=gpu)
torch.cuda.set_device(gpu)
netE.cuda(gpu)
netE.multi_gpu_wrapper(_transform_)
elif opt.distribution_type == 'single':
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
elif gpu is not None:
torch.cuda.set_device(gpu)
netE = netE.cuda(gpu)
else:
raise ValueError('distribution_type = multi | single | None')
if should_diag:
logger.info(opt)
optimizer = optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
if opt.netE != '':
ckpt = torch.load(opt.netE)
netE.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
start_epoch = ckpt['epoch'] + 1  # resume from the checkpoint without reloading it
else:
start_epoch = 0
for epoch in range(start_epoch, opt.niter):
if opt.distribution_type == 'multi':
train_sampler.set_epoch(epoch)
lr_scheduler.step(epoch)
for i, data in enumerate(dataloader):
x = data['real']
sv_x = data['raw']
sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
noises_batch = noises_init[data['idx']]
'''
train diffusion
'''
if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
sv_x = sv_x.cuda(gpu)
noises_batch = noises_batch.cuda(gpu)
elif opt.distribution_type == 'single':
sv_x = sv_x.cuda()
noises_batch = noises_batch.cuda()
loss = netE.get_loss_iter(sv_x, noises_batch).mean()
optimizer.zero_grad()
loss.backward()
netpNorm, netgradNorm = getGradNorm(netE)
if opt.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)
optimizer.step()
if i % opt.print_freq == 0 and should_diag:
logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
.format(
epoch, opt.niter, i, len(dataloader),loss.item(),
netpNorm, netgradNorm,
))
if (epoch + 1) % opt.diagIter == 0 and should_diag:
logger.info('Diagnosis:')
x_range = [x.min().item(), x.max().item()]
kl_stats = netE.all_kl(sv_x)
logger.info(' [{:>3d}/{:>3d}] '
'x_range: [{:>10.4f}, {:>10.4f}], '
'total_bpd_b: {:>10.4f}, '
'terms_bpd: {:>10.4f}, '
'prior_bpd_b: {:>10.4f} '
'mse_bt: {:>10.4f} '
.format(
epoch, opt.niter,
*x_range,
kl_stats['total_bpd_b'].item(),
kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
))
if (epoch + 1) % opt.vizIter == 0 and should_diag:
logger.info('Generation: eval')
netE.eval()
m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0)
with torch.no_grad():
x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
logger.info(' [{:>3d}/{:>3d}] '
'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
.format(
epoch, opt.niter,
*gen_eval_range, *gen_stats,
))
export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
(x_gen_eval*s+m).transpose(1, 2).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
(sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
(sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)
netE.train()
if (epoch + 1) % opt.saveIter == 0:
if should_diag:
save_dict = {
'epoch': epoch,
'model_state': netE.state_dict(),
'optimizer_state': optimizer.state_dict()
}
torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
if opt.distribution_type == 'multi':
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
netE.load_state_dict(
torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
dist.destroy_process_group()
def main():
opt = parse_args()
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
''' workaround '''
train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes)
noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)
if opt.dist_url == "env://" and opt.world_size == -1:
opt.world_size = int(os.environ["WORLD_SIZE"])
if opt.distribution_type == 'multi':
opt.ngpus_per_node = torch.cuda.device_count()
opt.world_size = opt.ngpus_per_node * opt.world_size
mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
else:
train(opt.gpu, opt, output_dir, noises_init)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/data_v0', help='path to the PartNet data root')
parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
help='path to the raw ShapeNet SDF point clouds')
parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
help='path to the ShapeNetCore.v2.PC15k point clouds')
parser.add_argument('--classes', default='Chair')
parser.add_argument('--bs', type=int, default=64, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--npoints', type=int, default=2048)
parser.add_argument('--svpoints', type=int, default=1024)
'''model'''
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', type=int, default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
parser.add_argument('--grad_clip', type=float, default=None, help='max norm for gradient clipping (None disables it)')
parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
parser.add_argument('--netE', default='', help="path to netE (to continue training)")
'''distributed'''
parser.add_argument('--world_size', default=1, type=int,
help='Number of distributed nodes.')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--rank', default=0, type=int,
help='node rank for distributed training')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use. None means using all available GPUs.')
'''eval'''
parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
main()

View file

@ -0,0 +1,825 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.partnet import GANdatasetPartNet
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
'''
some utils
'''
def rotation_matrix(axis, theta):
"""
Return the rotation matrix associated with counterclockwise rotation about
the given axis by theta radians.
"""
axis = np.asarray(axis)
axis = axis / np.sqrt(np.dot(axis, axis))
a = np.cos(theta / 2.0)
b, c, d = -axis * np.sin(theta / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
def rotate(vertices, faces):
'''
vertices: [numpoints, 3]
'''
M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
K = rotation_matrix([0, 0, 1], np.pi).transpose()
v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
return v, f
def norm(v, f):
v = (v - v.min())/(v.max() - v.min()) - 0.5
return v, f
def getGradNorm(net):
pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
return pNorm, gradNorm
def weights_init(m):
"""
xavier initialization
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1 and m.weight is not None:
torch.nn.init.xavier_normal_(m.weight)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_()
m.bias.data.fill_(0)
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data lies in [0, 1] (see the 0.001 / 0.999 thresholds below)
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): an intermediate sample is appended to the returned list every `freq` timesteps
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
def get_dataset(data_root, data_raw_root, pc_dataroot, npoints, category):
train_ds = GANdatasetPartNet('train', data_root, category, npoints)
return train_ds
def get_dataloader(opt, train_dataset, test_dataset=None):
if opt.distribution_type == 'multi':
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
if test_dataset is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
else:
test_sampler = None
else:
train_sampler = None
test_sampler = None
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
if test_dataset is not None:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
else:
test_dataloader = None
return train_dataloader, test_dataloader, train_sampler, test_sampler
def train(gpu, opt, output_dir, noises_init):
set_seed(opt)
logger = setup_logging(output_dir)
if opt.distribution_type == 'multi':
should_diag = gpu==0
else:
should_diag = True
if should_diag:
outf_syn, = setup_output_subdirs(output_dir, 'syn')
if opt.distribution_type == 'multi':
if opt.dist_url == "env://" and opt.rank == -1:
opt.rank = int(os.environ["RANK"])
base_rank = opt.rank * opt.ngpus_per_node
opt.rank = base_rank + gpu
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
world_size=opt.world_size, rank=opt.rank)
opt.bs = int(opt.bs / opt.ngpus_per_node)
opt.workers = 0
opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
''' data '''
train_dataset = get_dataset(opt.data_root, opt.data_raw_root,opt.pc_dataroot, opt.npoints,opt.classes)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
'''
create networks
'''
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
def _transform_(m):
return nn.parallel.DistributedDataParallel(
m, device_ids=[gpu], output_device=gpu)
torch.cuda.set_device(gpu)
netE.cuda(gpu)
netE.multi_gpu_wrapper(_transform_)
elif opt.distribution_type == 'single':
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
elif gpu is not None:
torch.cuda.set_device(gpu)
netE = netE.cuda(gpu)
else:
raise ValueError('distribution_type = multi | single | None')
if should_diag:
logger.info(opt)
optimizer = optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
if opt.netE != '':
ckpt = torch.load(opt.netE)
netE.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
start_epoch = ckpt['epoch'] + 1  # resume from the checkpoint without reloading it
else:
start_epoch = 0
for epoch in range(start_epoch, opt.niter):
if opt.distribution_type == 'multi':
train_sampler.set_epoch(epoch)
lr_scheduler.step(epoch)
for i, data in enumerate(dataloader):
x = data['real']
sv_x = data['raw']
sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
noises_batch = noises_init[data['idx']]
'''
train diffusion
'''
if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
sv_x = sv_x.cuda(gpu)
noises_batch = noises_batch.cuda(gpu)
elif opt.distribution_type == 'single':
sv_x = sv_x.cuda()
noises_batch = noises_batch.cuda()
loss = netE.get_loss_iter(sv_x, noises_batch).mean()
optimizer.zero_grad()
loss.backward()
netpNorm, netgradNorm = getGradNorm(netE)
if opt.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)
optimizer.step()
if i % opt.print_freq == 0 and should_diag:
logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
.format(
epoch, opt.niter, i, len(dataloader),loss.item(),
netpNorm, netgradNorm,
))
if (epoch + 1) % opt.diagIter == 0 and should_diag:
logger.info('Diagnosis:')
x_range = [x.min().item(), x.max().item()]
kl_stats = netE.all_kl(sv_x)
logger.info(' [{:>3d}/{:>3d}] '
'x_range: [{:>10.4f}, {:>10.4f}], '
'total_bpd_b: {:>10.4f}, '
'terms_bpd: {:>10.4f}, '
'prior_bpd_b: {:>10.4f} '
'mse_bt: {:>10.4f} '
.format(
epoch, opt.niter,
*x_range,
kl_stats['total_bpd_b'].item(),
kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
))
if (epoch + 1) % opt.vizIter == 0 and should_diag:
logger.info('Generation: eval')
netE.eval()
m, s = torch.tensor([[0,0,0]]).transpose(0,1)[None], torch.tensor([[1,1,1]]).transpose(0,1)[None]#train_dataset.get_pc_stats(0)
with torch.no_grad():
x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
logger.info(' [{:>3d}/{:>3d}] '
'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
.format(
epoch, opt.niter,
*gen_eval_range, *gen_stats,
))
export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
(x_gen_eval*s+m).transpose(1, 2).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
(sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
(sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)
netE.train()
if (epoch + 1) % opt.saveIter == 0:
if should_diag:
save_dict = {
'epoch': epoch,
'model_state': netE.state_dict(),
'optimizer_state': optimizer.state_dict()
}
torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
if opt.distribution_type == 'multi':
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
netE.load_state_dict(
torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
dist.destroy_process_group()
def main():
opt = parse_args()
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
''' workaround '''
train_dataset = get_dataset(opt.data_root, opt.data_raw_root, opt.pc_dataroot, opt.npoints,opt.classes)
noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)
if opt.dist_url == "env://" and opt.world_size == -1:
opt.world_size = int(os.environ["WORLD_SIZE"])
if opt.distribution_type == 'multi':
opt.ngpus_per_node = torch.cuda.device_count()
opt.world_size = opt.ngpus_per_node * opt.world_size
mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
else:
train(opt.gpu, opt, output_dir, noises_init)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='/home/ubuntu/01DATA/partnet/', help='path to the PartNet data root')
parser.add_argument('--data_raw_root', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_dim32_sdf_pc',
help='path to the raw ShapeNet SDF point clouds')
parser.add_argument('--pc_dataroot', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k',
help='path to the ShapeNetCore.v2.PC15k point clouds')
parser.add_argument('--classes', default='Table')
parser.add_argument('--bs', type=int, default=64, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--npoints', type=int, default=2048)
parser.add_argument('--svpoints', type=int, default=1024)
'''model'''
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', type=int, default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm for clipping (None disables)')
parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
parser.add_argument('--netE', default='', help="path to netE (to continue training)")
'''distributed'''
parser.add_argument('--world_size', default=1, type=int,
help='Number of distributed nodes.')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--rank', default=0, type=int,
help='node rank for distributed training')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use. None means using all available GPUs.')
'''eval'''
parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
main()

View file

@ -0,0 +1,822 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.partnet import GANdatasetPartNet
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
'''
some utils
'''
def rotation_matrix(axis, theta):
"""
Return the rotation matrix associated with counterclockwise rotation about
the given axis by theta radians.
"""
axis = np.asarray(axis)
axis = axis / np.sqrt(np.dot(axis, axis))
a = np.cos(theta / 2.0)
b, c, d = -axis * np.sin(theta / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
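As a quick sanity check of the convention (counterclockwise about the given axis), rotating the x unit vector 90 degrees about z should land on y. A minimal sketch, assuming `rotation_matrix` as defined above; the snippet is illustrative and not part of the script:

```python
import numpy as np

R = rotation_matrix([0, 0, 1], np.pi / 2)          # 90-degree CCW rotation about z
print(np.round(R @ np.array([1.0, 0.0, 0.0]), 6))  # -> [0. 1. 0.]
```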
def rotate(vertices, faces):
'''
vertices: [numpoints, 3]
'''
M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
K = rotation_matrix([0, 0, 1], np.pi).transpose()
v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
return v, f
def norm(v, f):
v = (v - v.min())/(v.max() - v.min()) - 0.5
return v, f
def getGradNorm(net):
    # call after loss.backward(); skip parameters whose grad is still None
    pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
    gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters() if p.grad is not None))
    return pNorm, gradNorm
def weights_init(m):
"""
xavier initialization
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1 and m.weight is not None:
torch.nn.init.xavier_normal_(m.weight)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_()
m.bias.data.fill_(0)
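`weights_init` dispatches on the class-name substring, so it is meant to be passed to `nn.Module.apply`, which visits every submodule. A small usage sketch (the toy network is hypothetical):

```python
import torch.nn as nn

toy = nn.Sequential(nn.Conv1d(3, 16, 1), nn.BatchNorm1d(16), nn.ReLU())
toy.apply(weights_init)  # xavier for conv weights; N(0,1) weight / zero bias for batchnorm
```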
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
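Since `normal_kl` takes log-variances, it can be cross-checked against `torch.distributions.kl_divergence` with `scale = exp(0.5 * logvar)`; a small verification sketch:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.zeros(5), torch.zeros(5)                   # N(0, 1)
mean2, logvar2 = torch.ones(5), torch.log(torch.full((5,), 4.0))  # N(1, 4)
ours = normal_kl(mean1, logvar1, mean2, logvar2)
ref = kl_divergence(Normal(mean1, (0.5 * logvar1).exp()),
                    Normal(mean2, (0.5 * logvar2).exp()))
assert torch.allclose(ours, ref, atol=1e-6)
```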
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data has been rescaled to [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
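`_extract` is a per-sample lookup followed by a broadcast-friendly reshape; since it is a `@staticmethod` it can be exercised directly. Illustrative values:

```python
import torch

a = torch.tensor([10., 20., 30.])   # one coefficient per timestep
t = torch.tensor([2, 0])            # per-sample timesteps for a batch of 2
out = GaussianDiffusion._extract(a, t, (2, 3, 4))
print(out.shape, out.flatten())     # torch.Size([2, 1, 1]) tensor([30., 10.])
```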
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
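For completion, `p_sample_loop` keeps the first `sv_points` columns fixed and only denoises the remainder. A shape-level sketch using a dummy zero-predicting function in place of the trained network (all values illustrative):

```python
import numpy as np
import torch

diff = GaussianDiffusion(np.linspace(1e-4, 0.02, 10), 'mse', 'eps', 'fixedsmall', sv_points=4)
dummy = lambda data, t: torch.zeros_like(data)   # stand-in for a trained eps-predictor
partial = torch.randn(2, 3, 4)                   # B x C x sv_points conditioning points
out = diff.p_sample_loop(partial, dummy, shape=(2, 3, 6), device='cpu')
assert out.shape == (2, 3, 10)
assert torch.equal(out[:, :, :4], partial)       # the partial input passes through unchanged
```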
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
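With `loss_type='mse'`, `p_losses` reduces to a per-sample MSE between the injected noise and the network's prediction on the free points only; continuing the dummy-denoiser sketch above:

```python
import numpy as np
import torch

diff = GaussianDiffusion(np.linspace(1e-4, 0.02, 10), 'mse', 'eps', 'fixedsmall', sv_points=4)
dummy = lambda data, t: torch.zeros_like(data)
data = torch.randn(2, 3, 10)                       # first 4 columns are the conditioning part
t = torch.randint(0, diff.num_timesteps, (2,))
print(diff.p_losses(dummy, data_start=data, t=t))  # per-sample losses, shape [2]
```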
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
# each entry: ((out_channels, num_blocks, voxel_resolution), (num_centers, radius, num_neighbors, mlp_channels))
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
# each entry: (fp_mlp_channels, (out_channels, num_blocks, voxel_resolution))
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
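The warm schedules hold beta at `b_end` except for a linear ramp over the first 10/20/50% of steps; a quick comparison of the resulting arrays:

```python
import numpy as np

for st in ('linear', 'warm0.1', 'warm0.5'):
    b = get_betas(st, 1e-4, 2e-2, 1000)
    print(st, float(b[0]), float(b[99]), float(b[-1]))
```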
def get_dataset(data_root, npoints, category):
train_ds = GANdatasetPartNet('train', data_root, category, npoints)
return train_ds
def get_dataloader(opt, train_dataset, test_dataset=None):
if opt.distribution_type == 'multi':
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
if test_dataset is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
else:
test_sampler = None
else:
train_sampler = None
test_sampler = None
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
if test_dataset is not None:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
                                              shuffle=False, num_workers=int(opt.workers), drop_last=False)
else:
test_dataloader = None
return train_dataloader, test_dataloader, train_sampler, test_sampler
def train(gpu, opt, output_dir, noises_init):
set_seed(opt)
logger = setup_logging(output_dir)
if opt.distribution_type == 'multi':
should_diag = gpu==0
else:
should_diag = True
if should_diag:
outf_syn, = setup_output_subdirs(output_dir, 'syn')
if opt.distribution_type == 'multi':
if opt.dist_url == "env://" and opt.rank == -1:
opt.rank = int(os.environ["RANK"])
base_rank = opt.rank * opt.ngpus_per_node
opt.rank = base_rank + gpu
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
world_size=opt.world_size, rank=opt.rank)
opt.bs = int(opt.bs / opt.ngpus_per_node)
opt.workers = 0
opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
''' data '''
train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
'''
create networks
'''
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
def _transform_(m):
return nn.parallel.DistributedDataParallel(
m, device_ids=[gpu], output_device=gpu)
torch.cuda.set_device(gpu)
netE.cuda(gpu)
netE.multi_gpu_wrapper(_transform_)
elif opt.distribution_type == 'single':
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
elif gpu is not None:
torch.cuda.set_device(gpu)
netE = netE.cuda(gpu)
else:
raise ValueError('distribution_type = multi | single | None')
if should_diag:
logger.info(opt)
optimizer= optim.Adam(netE.parameters(), lr=opt.lrE, weight_decay=opt.e_decay, betas=(opt.beta1, 0.999))
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.e_gamma)
if opt.netE != '':
    ckpt = torch.load(opt.netE)
    netE.load_state_dict(ckpt['model_state'])
    optimizer.load_state_dict(ckpt['optimizer_state'])
    start_epoch = ckpt['epoch'] + 1
else:
    start_epoch = 0
for epoch in range(start_epoch, opt.niter):
if opt.distribution_type == 'multi':
train_sampler.set_epoch(epoch)
lr_scheduler.step(epoch)
for i, data in enumerate(dataloader):
x = data['real']
sv_x = data['raw']
sv_x = torch.cat([sv_x, x[:, :, opt.svpoints:]], dim=-1)
noises_batch = noises_init[data['idx']]
'''
train diffusion
'''
if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
sv_x = sv_x.cuda(gpu)
noises_batch = noises_batch.cuda(gpu)
elif opt.distribution_type == 'single':
sv_x = sv_x.cuda()
noises_batch = noises_batch.cuda()
loss = netE.get_loss_iter(sv_x, noises_batch).mean()
optimizer.zero_grad()
loss.backward()
netpNorm, netgradNorm = getGradNorm(netE)
if opt.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(netE.parameters(), opt.grad_clip)
optimizer.step()
if i % opt.print_freq == 0 and should_diag:
logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
.format(
epoch, opt.niter, i, len(dataloader),loss.item(),
netpNorm, netgradNorm,
))
if (epoch + 1) % opt.diagIter == 0 and should_diag:
logger.info('Diagnosis:')
x_range = [x.min().item(), x.max().item()]
kl_stats = netE.all_kl(sv_x)
logger.info(' [{:>3d}/{:>3d}] '
'x_range: [{:>10.4f}, {:>10.4f}], '
'total_bpd_b: {:>10.4f}, '
'terms_bpd: {:>10.4f}, '
'prior_bpd_b: {:>10.4f} '
'mse_bt: {:>10.4f} '
.format(
epoch, opt.niter,
*x_range,
kl_stats['total_bpd_b'].item(),
kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
))
if (epoch + 1) % opt.vizIter == 0 and should_diag:
logger.info('Generation: eval')
netE.eval()
m, s = torch.tensor([[0, 0, 0]]).transpose(0, 1)[None], torch.tensor([[1, 1, 1]]).transpose(0, 1)[None]  # identity mean/std; train_dataset.get_pc_stats(0) would give real stats
with torch.no_grad():
x_gen_eval = netE.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
logger.info(' [{:>3d}/{:>3d}] '
'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
.format(
epoch, opt.niter,
*gen_eval_range, *gen_stats,
))
export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
(x_gen_eval*s+m).transpose(1, 2).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
(sv_x.detach().cpu()*s+m).transpose(1, 2).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
(sv_x[:,:,:opt.svpoints].detach().cpu()*s+m).transpose(1, 2).numpy()*3)
netE.train()
if (epoch + 1) % opt.saveIter == 0:
if should_diag:
save_dict = {
'epoch': epoch,
'model_state': netE.state_dict(),
'optimizer_state': optimizer.state_dict()
}
torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
if opt.distribution_type == 'multi':
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
netE.load_state_dict(
torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
if opt.distribution_type == 'multi':
    dist.destroy_process_group()
def main():
opt = parse_args()
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
''' workaround '''
train_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
noises_init = torch.randn(len(train_dataset), opt.nc, opt.npoints-opt.svpoints)
if opt.dist_url == "env://" and opt.world_size == -1:
opt.world_size = int(os.environ["WORLD_SIZE"])
if opt.distribution_type == 'multi':
opt.ngpus_per_node = torch.cuda.device_count()
opt.world_size = opt.ngpus_per_node * opt.world_size
mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
else:
train(opt.gpu, opt, output_dir, noises_init)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet/', help='path to the PartNet data root')
parser.add_argument('--classes', default='Table')
parser.add_argument('--bs', type=int, default=64, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--npoints', type=int, default=2048)
parser.add_argument('--svpoints', type=int, default=1024)
'''model'''
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', type=int, default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
parser.add_argument('--lrE', type=float, default=2e-4, help='learning rate for E, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--e_decay', type=float, default=0, help='weight decay for EBM')
parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm for clipping (None disables)')
parser.add_argument('--e_gamma', type=float, default=0.998, help='lr decay for EBM')
parser.add_argument('--netE', default='', help="path to netE (to continue training)")
'''distributed'''
parser.add_argument('--world_size', default=1, type=int,
help='Number of distributed nodes.')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--rank', default=0, type=int,
help='node rank for distributed training')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use. None means using all available GPUs.')
'''eval'''
parser.add_argument('--saveIter', default=100, type=int,help='unit: epoch')
parser.add_argument('--diagIter', default=50, type=int, help='unit: epoch')
parser.add_argument('--vizIter', default=50, type=int,help='unit: epoch')
parser.add_argument('--print_freq', default=50, type=int,help='unit: iter')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
main()

View file

@ -0,0 +1,660 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data has been rescaled to [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb, denoise_fn, noise_fn=torch.randn):
assert t >= 1
t_vec = torch.empty(x0_part.shape[0], dtype=torch.int64, device=x0_part.device).fill_(t-1)
encoding0 = self.q_sample(x0_part, t_vec)
encoding1 = self.q_sample(x1_part, t_vec)
enc = (1 - lamb) * encoding0 + lamb * encoding1
img_t = torch.cat([torch.cat([x0_sv[:,:,:int(self.sv_points*(1-lamb))], x1_sv[:,:,:(self.sv_points - int(self.sv_points*(1-lamb)))]], dim=-1), enc], dim=-1)
for k in reversed(range(0,t)):
t_ = torch.empty(img_t.shape[0], dtype=torch.int64, device=img_t.device).fill_(k)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=False, return_pred_xstart=False).detach()
return img_t
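`interpolate` diffuses the free points of both completions to step t, linearly blends the two noisy encodings (while splicing together the two partial inputs in a lamb-weighted ratio), then runs the reverse chain from t back to 0. A shape-level sketch with a dummy denoiser standing in for the trained network (illustrative only):

```python
import numpy as np
import torch

diff = GaussianDiffusion(np.linspace(1e-4, 0.02, 10), 'mse', 'eps', 'fixedsmall', sv_points=4)
dummy = lambda data, t: torch.zeros_like(data)
x0_part, x1_part = torch.randn(2, 3, 6), torch.randn(2, 3, 6)  # free points of two completions
x0_sv, x1_sv = torch.randn(2, 3, 4), torch.randn(2, 3, 4)      # their partial inputs
out = diff.interpolate(x0_part, x1_part, x0_sv, x1_sv, t=10, lamb=0.5, denoise_fn=dummy)
assert out.shape == (2, 3, 10)
```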
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def interpolate(self, x0_part, x1_part, x0_sv, x1_sv, t, lamb):
return self.diffusion.interpolate(x0_part, x1_part, x0_sv, x1_sv, t, lamb, self._denoise)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
cache=os.path.join(mesh_root, '../cache'), split='val',
categories=category,
radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category, get_image=True):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=category,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std, get_image=get_image,
)
return te_dataset
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
if i != 3:  # hard-coded: only visualize one fixed batch
    continue
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
for v in range(20):
recons = []
svs = []
for p in [0,1]:
x = x_all[:,p].transpose(1, 2).contiguous()
img = img_all[:,p]
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
recons.append(recon)
svs.append(x[:, :opt.svpoints,:])
for l, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
# im = np.fliplr(np.flipud(d[-1]))
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p),
(torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,l), 'mode_%03d'%v, 'depth_%d'%p),
(torch.stack(d[:-1], dim=0)* s[0] + m[0]).numpy())
plt.imsave(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v, 'depth_%d.png' % p),
d[-1].permute(1, 2, 0), cmap='gray')
x0_part = recons[0].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda()
x1_part = recons[1].transpose(1, 2).contiguous()[:,:,opt.svpoints:].cuda()
x0_sv = svs[0].transpose(1,2).cuda()
x1_sv = svs[1].transpose(1,2).cuda()
interres = []
for lamb in np.linspace(0.1, 0.9, 5):
res = netE.interpolate(x0_part, x1_part, x0_sv, x1_sv, 1000, lamb)
res = torch.cat([x0_sv, x1_sv, res[:,:,opt.svpoints:]], dim=-1).detach().cpu().transpose(1,2).contiguous()
interres.append(res)
for l, d in enumerate(torch.stack(interres, dim=1)):
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, l), 'mode_%03d' % v),
(d* s[0] + m[0]).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, l), 'mode_%03d' % v),
(d * s[0] + m[0]).numpy())
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point-cloud dataset')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
                    help='path to the preprocessed multi-view data')
parser.add_argument('--classes', default=['chair'])
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=False)
parser.add_argument('--generate_multimodal', default=False)
parser.add_argument('--eval_saved', default=False)
parser.add_argument('--eval_redwood', default=True)
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--npoints', type=int, default=2048)
parser.add_argument('--svpoints', type=int, default=200)
'''model'''
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', type=int, default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', type=float, default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help='directory containing the checkpoints to evaluate')
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=None, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
return opt
if __name__ == '__main__':
opt = parse_args()
set_seed(opt)
main(opt)

View file

@ -0,0 +1,706 @@
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer2 import write_to_xml_batch, write_to_xml
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data has been rescaled to [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
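# Sketch of the identity implemented above (cf. Ho et al., 2020): the posterior
# q(x_{t-1} | x_t, x_0) is Gaussian with mean coef1 * x_0 + coef2 * x_t, where
# coef1 = beta_t * sqrt(alphabar_{t-1}) / (1 - alphabar_t) and
# coef2 = (1 - alphabar_{t-1}) * sqrt(alpha_t) / (1 - alphabar_t).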
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
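Example (hypothetical shapes): with a of shape [T], t of shape [B] and
x_shape == [B, 3, N], the result has shape [B, 1, 1] so it broadcasts
over channels and points.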
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
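Closed form: x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps with
eps ~ N(0, I), so any timestep can be sampled without iterating the chain.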
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
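# The sv_points conditioning columns are re-attached unchanged above, so the
# reverse chain only ever denoises the completion slice.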
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
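Shape sketch (an assumption from the [B, 3, N] layout used here): partial_x is
[B, 3, sv_points] and shape == [B, 3, npoints - sv_points]; the return value is
the partial points concatenated with the completed points along the last dim.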
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory2(self, partial_x, denoise_fn, shape, device, num_save,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate point clouds at num_save
log-spaced timesteps. Useful for visualizing how the completion
evolves over the reverse process.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
scale = np.exp(np.log(1/total_steps)/num_save)
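# scale = total_steps ** (-1 / num_save) < 1, so save_step below shrinks
# geometrically from total_steps toward 1 and snapshots are log-spaced in t.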
save_step = total_steps
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
imgs = [img_t.detach().cpu()]
for t in reversed(range(0,total_steps)):
if (t+1) == save_step and t > 0 and len(imgs)<num_save:
imgs.append(img_t.detach().cpu())
save_step = int(save_step * scale)
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
imgs.append(img_t.detach().cpu())
assert imgs[-1][:,:,self.sv_points:].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
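# Total bpd is the sum of per-timestep VLB terms plus the prior term; both are
# divided by log(2) in the helpers, so values are in bits (mse_bt_ is a plain
# MSE diagnostic, not in bits).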
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
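# Our reading of the PVCNN2Base convention (an assumption from the PVCNN code):
# each sa_block pairs a voxel-conv config (out_channels, num_blocks, voxel_resolution)
# with a set-abstraction config (num_centers, radius, num_neighbors, mlp_channels);
# each fp_block pairs feature-propagation mlp_channels with an analogous conv config.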
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def gen_samples_traj2(self, partial_x, shape, device, noise_fn=torch.randn, num_save=20,
clip_denoised=False,
keep_running=False):
return self.diffusion.p_sample_loop_trajectory2(partial_x, self._denoise, shape=shape, device=device, num_save=num_save, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
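# Example (matches the defaults below): get_betas('linear', 1e-4, 0.02, 1000) is the
# DDPM linear schedule np.linspace(1e-4, 0.02, 1000); the 'warmX' variants hold beta
# at b_end after a linear warmup over the first X fraction of the steps.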
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
cache=os.path.join(mesh_root, '../cache'), split='val',
categories=category,
radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=category,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def generate_video(netE, opt, save_dir):
test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
# gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
export_to_pc_batch(
os.path.join(save_dir, 'batch_%03d_ply' % i), x_all[:, :opt.svpoints, :].numpy())
write_to_xml_batch(os.path.join(save_dir, 'batch_%03d' % i),
x_all[:, :opt.svpoints, :].numpy(), cat='chair')
for v in range(6):
x = x_all.transpose(1, 2).contiguous()
img = img_all
m, s = data['mean'].float(), data['std'].float()
gen_all = netE.gen_samples_traj2(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', num_save=55,
clip_denoised=False)
gen_all = torch.stack(gen_all, dim=1).detach().cpu()
gen_all = gen_all.transpose(2, 3).contiguous()
gen_all = gen_all * s[:, None] + m[:, None]
x = x.transpose(1, 2).contiguous()
for p, d in enumerate(zip(list(gen_all), list(img))):
im = np.fliplr(np.flipud(d[-1]))
gen = d[0]
plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
# gen
write_to_xml_batch(os.path.join(save_dir, 'batch_%03d'%i, 'sample_%03d/mode_%03d/xml/gen_process/' % (p,v)),
gen.numpy(), cat='chair')
for p, gen in enumerate(gen_all[:,-1]):
Path(os.path.join(save_dir, 'batch_%03d_ply' % i, 'sample_%03d' % p,
'mode_%03d' % v)).mkdir(parents=True, exist_ok=True)
pcwrite(
os.path.join(save_dir, 'batch_%03d_ply' % i, 'sample_%03d/mode_%03d/partial.ply' % (p,v)), gen.numpy())
for k, pcl in enumerate(gen_all[:, -1].cpu().numpy()):
dir_ = os.path.join(save_dir, 'batch_%03d' % i, 'sample_%03d/mode_%03d/xml/rotate_final/' % (k, v))
Path(dir_).mkdir(parents=True, exist_ok=True)
for azim in np.linspace(45, 405 - (360 / 50), 50):
write_to_xml(
os.path.join(dir_, 'azim_%03d.xml' % azim),
pcl, cat='chair', elev=19.471, azim=azim)
def generate_video_redwood(netE, opt, save_dir):
import open3d as o3d
pth = "/viscam/u/alexzhou907/research/diffusion/redwood/09620_pc_partial.ply"
pth_gt = "/viscam/u/alexzhou907/research/diffusion/redwood/09620_pc.ply"
points = np.asarray(o3d.io.read_point_cloud(pth).points)
gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)
np.save('gt.npy', gt_points)
test_dataset = ShapeNet15kPointClouds(root_dir=opt.dataroot_pc,
categories=opt.classes, split='train',
tr_sample_size=opt.npoints,
te_sample_size=opt.npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()
x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.svpoints, replace=False)]).float()
x = (x - m) / s
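# The raw scan is normalized with training-set statistics (m, s) so it matches
# the coordinate distribution the completion model was trained on.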
x = x[None].transpose(1, 2).cuda()
shape = list(x.shape)
shape[-1] = opt.npoints - shape[-1]
res = []
for v in tqdm(range(20)):
gen_all = netE.gen_samples_traj2(x.cuda(), torch.Size(shape), 'cuda', num_save=55,
clip_denoised=False)
gen_all = torch.stack(gen_all, dim=1).detach().cpu()
gen_all = gen_all.transpose(2, 3).contiguous()
gen_all = gen_all * s[:, None] + m[:, None]
res.append(gen_all[:, -1].cpu())
for p, gen in enumerate(gen_all):
# gen
write_to_xml_batch(
os.path.join(save_dir, 'mode_%03d/xml/gen_process/' % ( v)),
gen.numpy(), cat='chair')
for k, pcl in enumerate(gen_all[:, -1].cpu().numpy()):
dir_ = os.path.join(save_dir, 'mode_%03d/xml/rotate_final/' % ( v))
Path(dir_).mkdir(parents=True, exist_ok=True)
for azim in np.linspace(45, 405 - (360 / 50), 50):
write_to_xml(
os.path.join(dir_, 'azim_%03d.xml' % azim),
pcl, cat='chair', elev=19.471, azim=azim)
pcwrite(os.path.join(save_dir, 'mode_%03d.ply'%v), gen_all[:, -1].cpu().numpy()[0])
pcwrite(os.path.join(save_dir, 'gt.ply'),
gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
for ckpt in sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]), reverse=True):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
generate_video_redwood( netE,opt, outf_syn)
exit()
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point cloud root')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='path to the single-view data root')
parser.add_argument('--classes', default=['chair'])
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=False)
parser.add_argument('--generate_multimodal', default=False)
parser.add_argument('--eval_saved', default=False)
parser.add_argument('--eval_redwood', default=True)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=200)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="directory of model checkpoints (.pth) to evaluate")
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
main(opt)

View file

@ -0,0 +1,681 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate point clouds every freq steps.
Useful for visualizing how denoised samples evolve over time.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
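# Note: unlike p_sample_loop above, no partial_x is concatenated here, so the
# first sv_points columns of the initial noise play the role of the partial
# shape; this helper looks like a leftover from the unconditional sampler.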
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
cache=os.path.join(mesh_root, '../cache'), split='val',
categories=category,
radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=category,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def evaluate_recon_mvr(opt, netE, save_dir, logger):
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
ref = []
samples = []
images = []
masked = []
k = 0
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
B,V,N,C = x_all.shape
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
img = img_all.reshape(B * V, *img_all.shape[2:])
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
x_adj = x.reshape(B,V,N,C)* s + m
recon_adj = recon.reshape(B,V,N,C)* s + m
img = img.reshape(B,V,*img.shape[1:])
ref.append( gt_all * s + m)
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
samples.append(recon_adj)
images.append(img)
ref_pcs = torch.cat(ref, dim=0)
sample_pcs = torch.cat(samples, dim=0)
images = torch.cat(images, dim=0)
masked = torch.cat(masked, dim=0)
B, V, N, C = ref_pcs.shape
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
# torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
# Compute metrics
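# EMD_CD with reduced=False returns per-sample metrics of shape [B*V]; they are
# folded back to [B, V] below so results can be read per shape and per view.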
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
results['pc'] = sample_pcs
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True)
Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True)
for v in range(5):
x = x_all.transpose(1, 2).contiguous()
img = img_all
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
im = np.fliplr(np.flipud(d[-1]))
plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
for ckpt in sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]), reverse=True):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
if opt.eval_recon_mvr:
# Evaluate generation
evaluate_recon_mvr(opt, netE, outf_syn, logger)
if opt.generate_multimodal:
generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point cloud root')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='path to the view data root')
parser.add_argument('--classes', default=['car'])
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=True)
parser.add_argument('--generate_multimodal', default=False)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=200)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/3_res32_pc_car_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-03-08-40', help="directory of model checkpoints (.pth) to evaluate")
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
set_seed(opt)
main(opt)

View file

@ -0,0 +1,753 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): an intermediate sample is appended to the returned list every `freq` timesteps.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
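# With loss_type == 'mse' this is the simplified DDPM objective: the network is
# regressed onto the injected noise eps rather than onto x_start itself. The
# conditioning points enter the network clean (concatenated in front of the
# noised slice), but the loss is taken only over the generated slice via the
# [:, :, self.sv_points:] indexing above.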
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
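# vals_bt_ and mse_bt_ hold one column per timestep; the boolean mask_bt writes
# each batch of per-timestep terms into column t without an explicit index_put.
# Summing the T columns and adding the prior KL gives the full variational
# bound in bits per dimension (the / np.log(2.) conversions happen inside
# _vb_terms_bpd and _prior_bpd).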
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
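# Calling convention, inferred from the evaluation code below (all clouds are
# channels-first, shape (batch, 3, points)):
#   # x: (B, 3, N) normalized cloud whose first opt.svpoints points are observed
#   recon = model.gen_samples(x[:, :, :opt.svpoints].cuda(),
#                             x[:, :, opt.svpoints:].shape, 'cuda',
#                             clip_denoised=False)  # -> (B, 3, N) on the GPU
#   # the first opt.svpoints columns of recon equal the conditioning input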
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
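# Example schedules with the defaults used by parse_args below
# (beta_start=0.0001, beta_end=0.02, time_num=1000):
#   get_betas('linear', 0.0001, 0.02, 1000)   # 1000 evenly spaced betas
#   get_betas('warm0.1', 0.0001, 0.02, 1000)  # linear ramp over the first 100 steps, then flat at 0.02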
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
cache=os.path.join(mesh_root, '../cache'), split='val',
categories=category,
radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=category,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
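# Both dataset helpers normalize the test split with the *training* set's
# statistics (all_points_mean / all_points_std), so the evaluation code below
# can map samples back to the original scale with the same `* s + m` transform.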
def evaluate_recon_mvr(opt, netE, save_dir, logger):
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
ref = []
samples = []
images = []
masked = []
k = 0
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
# img_all = data['image']
B,V,N,C = x_all.shape
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
# img = img_all.reshape(B * V, *img_all.shape[2:])
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
x_adj = x.reshape(B,V,N,C)* s + m
recon_adj = recon.reshape(B,V,N,C)* s + m
# img = img.reshape(B,V,*img.shape[1:])
ref.append( gt_all * s + m)
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
samples.append(recon_adj)
# images.append(img)
ref_pcs = torch.cat(ref, dim=0)
sample_pcs = torch.cat(samples, dim=0)
# images = torch.cat(images, dim=0)
masked = torch.cat(masked, dim=0)
B, V, N, C = ref_pcs.shape
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
# torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
# Compute metrics
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
results['pc'] = sample_pcs
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
del ref_pcs, masked, results
def evaluate_saved(opt, netE, save_dir, logger):
ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
gt_pth = ours_base + '/recon_gt.pth'
ours_pth = ours_base + '/ours_results.pth'
gt = torch.load(gt_pth).permute(1,0,2,3)
ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
all_res = {}
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
results = compute_all_metrics(gt_, ours_, opt.batch_size)
for key, val in results.items():
if i == 0:
all_res[key] = val
else:
all_res[key] += val
pprint(results)
for key, val in all_res.items():
all_res[key] = val / gt.shape[0]
pprint({key: val.mean().item() for key, val in all_res.items()})
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
for v in range(6):
x = x_all.transpose(1, 2).contiguous()
img = img_all
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
im = np.fliplr(np.flipud(d[-1]))
plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def redwood_demo(opt, netE, save_dir, logger):
import open3d as o3d
pth = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc_partial.ply"
pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/09515_pc.ply"
points = np.asarray(o3d.io.read_point_cloud(pth).points)
gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)
np.save('gt.npy', gt_points)
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()
x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float()
x = (x-m)/s
x = x.transpose(0, 1)[None].cuda()  # (N, 3) -> (1, 3, N): gen_samples expects a batched, channels-first cloud
res = []
for k in range(20):
recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
recon = recon * s+ m
res.append(recon)
res = torch.cat(res, dim=0)
write_to_xml_batch(os.path.join(save_dir, 'xml'),
(res).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'ply'),
(res).numpy())
torch.save(res, os.path.join(save_dir, 'redwood_demo.pth'))
pcwrite(os.path.join(save_dir, 'ply', 'gt.ply'),
gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
write_to_xml_batch(os.path.join(save_dir, 'xml_gt'),
gt_points[None], cat='chair')
exit()
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
for ckpt in reversed(sorted(ckpts, key=lambda x: int(x.strip('.pth').split('_')[-1]) )):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
if opt.eval_recon_mvr:
# Evaluate generation
evaluate_recon_mvr(opt, netE, outf_syn, logger)
if opt.generate_multimodal:
generate_multimodal(opt, netE, outf_syn, logger)
if opt.eval_saved:
evaluate_saved(opt, netE, outf_syn, logger)
if opt.eval_redwood:
redwood_demo(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='input batch size')
parser.add_argument('--classes', default=['chair'])
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=False)
parser.add_argument('--generate_multimodal', default=False)
parser.add_argument('--eval_saved', default=False)
parser.add_argument('--eval_redwood', default=True)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=200)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/6_res32_pc_chair_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-11-01-19-21-18', help="path to netE (to continue training)")
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
main(opt)

View file

@@ -0,0 +1,634 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: if True, run for len(self.betas) timesteps instead of self.num_timesteps (the two only differ if the beta schedule is extended in __init__)
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): an intermediate sample is appended to the returned list every `freq` timesteps.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
#############################################################################
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=[category], split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=[category],
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def evaluate_recon_mvr(opt, model, save_dir, logger):
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.category)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
ref = []
samples = []
masked = []
k = 0
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
B,V,N,C = x_all.shape
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
m, s = data['mean'].float(), data['std'].float()
recon = model.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
x_adj = x.reshape(B,V,N,C)* s + m
recon_adj = recon.reshape(B,V,N,C)* s + m
ref.append( gt_all * s + m)
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
samples.append(recon_adj)
ref_pcs = torch.cat(ref, dim=0)
sample_pcs = torch.cat(samples, dim=0)
masked = torch.cat(masked, dim=0)
B, V, N, C = ref_pcs.shape
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
# Compute metrics
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
results['pc'] = sample_pcs
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
del ref_pcs, masked, results
def evaluate_saved(opt, saved_dir):
# ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
gt_pth = saved_dir + '/recon_gt.pth'
ours_pth = saved_dir + '/ours_results.pth'
gt = torch.load(gt_pth).permute(1,0,2,3)
ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
all_res = {}
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
results = compute_all_metrics(gt_, ours_, opt.batch_size)
for key, val in results.items():
if i == 0:
all_res[key] = val
else:
all_res[key] += val
pprint(results)
for key, val in all_res.items():
all_res[key] = val / gt.shape[0]
pprint({key: val.mean().item() for key, val in all_res.items()})
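# Metrics are computed one view at a time: permute(1, 0, 2, 3) puts the V views
# first, compute_all_metrics runs per view, and the accumulated results are
# divided by gt.shape[0] (= V) to report a per-view average.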
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
model.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
model.eval()
with torch.no_grad():
logger.info("Resume Path:%s" % opt.model)
resumed_param = torch.load(opt.model)
model.load_state_dict(resumed_param['model_state'])
if opt.eval_recon_mvr:
# Evaluate generation
evaluate_recon_mvr(opt, model, outf_syn, logger)
if opt.eval_saved:
evaluate_saved(opt, outf_syn)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='input batch size')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='input batch size')
parser.add_argument('--category', default='chair')
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=True)
parser.add_argument('--eval_saved', default=True)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=200)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
parser.add_argument('--model', default='', required=True, help="path to model (to continue training)")
'''eval'''
parser.add_argument('--eval_path',
default='')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
main(opt)

View file

@@ -0,0 +1,599 @@
from pprint import pprint
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.partnet import GANdatasetPartNet
import trimesh
import csv
import numpy as np
import random
from plyfile import PlyData, PlyElement
def write_ply(points, filename, text=False):
""" input: Nx3, write points to filename as PLY format. """
points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
with open(filename, mode='wb') as f:
PlyData([el], text=text).write(f)
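# Minimal usage sketch (the filename is illustrative):
#   pts = np.random.rand(2048, 3).astype(np.float32)
#   write_ply(pts, 'completion.ply')  # binary PLY; pass text=True for ASCII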
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: if True, run for len(self.betas) timesteps instead of self.num_timesteps (the two only differ if the beta schedule is extended in __init__)
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): an intermediate sample is appended to the returned list every `freq` timesteps.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
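# A minimal sanity check of the schedules above (illustrative only, not part of
# the original script): 'linear' ramps over the full range, while 'warm0.1'
# ramps for the first 10% of the steps and then holds b_end.
def _example_beta_schedules():
    betas_linear = get_betas('linear', 0.0001, 0.02, 1000)
    betas_warm = get_betas('warm0.1', 0.0001, 0.02, 1000)
    assert betas_linear.shape == betas_warm.shape == (1000,)
    assert np.isclose(betas_linear[-1], 0.02)
    assert np.isclose(betas_warm[100:], 0.02).all()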
#############################################################################
def get_dataset(data_root, npoints, category):
test_ds = GANdatasetPartNet('test', data_root, category, npoints)  # 'test' split: this loader is used for generation, not training
return test_ds
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['real']
x_all = data['raw']
for j in range(5):
x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1)
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))):
partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1)
rec = d[1]
rid = d[2]
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j),
(torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j),
(torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy())
raw_id = rid.split('.')[0]
save_sample_dir = os.path.join(save_dir, "{}".format(raw_id))
Path(save_sample_dir).mkdir(parents=True, exist_ok=True)
# save input partial shape
if j == 0:
save_path = os.path.join(save_sample_dir, "raw.ply")
write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path)
# save completed shape
save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j))
write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path)
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
# use splitext: str.strip('.pth') strips a character set, not the suffix
for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(x)[0].split('_')[-1]))):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
if opt.generate_multimodal:
generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet',
help='path to the PartNet data root')
parser.add_argument('--classes', default='Chair')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=False)
parser.add_argument('--generate_multimodal', default=True)
parser.add_argument('--eval_saved', default=False)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=1024)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)")
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
main(opt)

View file

@ -0,0 +1,599 @@
from pprint import pprint
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.partnet import GANdatasetPartNet
import trimesh
import csv
import numpy as np
import random
from plyfile import PlyData, PlyElement
def write_ply(points, filename, text=False):
""" input: Nx3, write points to filename as PLY format. """
points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
with open(filename, mode='wb') as f:
PlyData([el], text=text).write(f)
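# Hedged usage sketch for write_ply above (not in the original source): dump a
# random point cloud and read it back with plyfile to confirm the layout. The
# temp-file path is an assumption of this example.
def _example_write_ply():
    import os
    import tempfile
    pts = np.random.randn(1024, 3).astype(np.float32)
    path = os.path.join(tempfile.mkdtemp(), 'example.ply')
    write_ply(pts, path)
    assert PlyData.read(path)['vertex'].count == 1024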
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
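# Quick numerical check of normal_kl (illustrative sketch, not original code):
# the KL of a standard normal against itself is 0, and shifting one mean by 1
# with unit variances gives exactly 0.5.
def _example_normal_kl():
    zero = torch.zeros(1)
    assert torch.allclose(normal_kl(zero, zero, zero, zero), zero)
    assert torch.allclose(normal_kl(torch.ones(1), zero, zero, zero),
                          torch.full((1,), 0.5))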
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): interval, in timesteps, at which intermediate point clouds are
appended to the returned list of samples.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
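# Illustrative sketch (not part of the original file): the completion layout this
# class assumes is [B, 3, N] point clouds whose first sv_points columns hold the
# observed partial shape; only the remaining columns are ever diffused. Assumes
# get_betas, which is defined further below in this file.
def _example_completion_layout(sv_points=1024, npoints=2048):
    diffusion = GaussianDiffusion(get_betas('linear', 0.0001, 0.02, 1000),
                                  'mse', 'eps', 'fixedsmall', sv_points)
    x0 = torch.randn(2, 3, npoints)
    t = torch.randint(0, diffusion.num_timesteps, (2,))
    # Diffuse only the free points, then re-attach the clean partial shape,
    # mirroring p_losses / calc_bpd_loop above.
    x_t = diffusion.q_sample(x_start=x0[:, :, sv_points:], t=t)
    data_t = torch.cat([x0[:, :, :sv_points], x_t], dim=-1)
    assert data_t.shape == x0.shape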
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
#############################################################################
def get_dataset(data_root, npoints, category):
test_ds = GANdatasetPartNet('test', data_root, category, npoints)  # 'test' split: this loader is used for generation, not training
return test_ds
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_dataset(opt.data_root, opt.npoints,opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['real']
x_all = data['raw']
for j in range(5):
x = torch.cat([x_all, gt_all[:, :, opt.svpoints:]], dim=-1)
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
for p, d in enumerate(zip(list(x_all), list(recon), list(data['raw_id']))):
partial = torch.cat([d[0], d[0][:, 0:1].expand(-1, opt.svpoints)], dim=-1)
rec = d[1]
rid = d[2]
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d' % (i, p), 'mode_%03d' % j),
(torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d' % (i, p), 'mode_%03d' % j),
(torch.stack([partial, rec], dim=0)).transpose(1, 2).numpy())
raw_id = rid.split('.')[0]
save_sample_dir = os.path.join(save_dir, "{}".format(raw_id))
Path(save_sample_dir).mkdir(parents=True, exist_ok=True)
# save input partial shape
if j == 0:
save_path = os.path.join(save_sample_dir, "raw.ply")
write_ply((partial.detach().cpu()).transpose(0, 1).numpy(), save_path)
# save completed shape
save_path = os.path.join(save_sample_dir, "fake-z{}.ply".format(j))
write_ply((rec.detach().cpu()).transpose(0, 1).numpy(), save_path)
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
# use splitext: str.strip('.pth') strips a character set, not the suffix
for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(x)[0].split('_')[-1]))):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
if opt.generate_multimodal:
generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='/viscam/u/alexzhou907/01DATA/partnet',
help='path to the PartNet data root')
parser.add_argument('--classes', default='Table')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=False)
parser.add_argument('--generate_multimodal', default=True)
parser.add_argument('--eval_saved', default=False)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=1024)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/epn3d_chair', help="path to netE (to continue training)")
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
main(opt)

View file

@ -0,0 +1,681 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
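# Illustrative check (not original code): with unit scale and data exactly at the
# mean, the function integrates the standard normal over a width-1 bin centered
# at 0, so the log-likelihood is log(Phi(0.5) - Phi(-0.5)) ~= log(0.3829).
def _example_discretized_ll():
    x = torch.full((4,), 0.5)
    ll = discretized_gaussian_log_likelihood(
        x, means=torch.full((4,), 0.5), log_scales=torch.zeros(4))
    assert torch.allclose(ll, torch.full((4,), float(np.log(0.3829))), atol=1e-3)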
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): interval, in timesteps, at which intermediate point clouds are
appended to the returned list of samples.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
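# Hedged sketch (not in the original file): p_sample_loop expects a denoiser that
# maps the full [B, 3, sv_points + n_free] cloud and a timestep batch to a
# same-shaped tensor; a zero eps-predictor is enough to trace the shapes. The
# 10-step schedule just keeps the example cheap; get_betas is defined below.
def _example_p_sample_loop(sv_points=200, n_free=1848):
    diffusion = GaussianDiffusion(get_betas('linear', 0.0001, 0.02, 10),
                                  'mse', 'eps', 'fixedsmall', sv_points)
    dummy_denoise = lambda data, t: torch.zeros_like(data)
    partial = torch.randn(2, 3, sv_points)
    completed = diffusion.p_sample_loop(partial, dummy_denoise, shape=(2, 3, n_free),
                                        device='cpu', clip_denoised=False)
    # The observed points are passed through untouched; only the tail is sampled.
    assert completed.shape == (2, 3, sv_points + n_free)
    assert torch.equal(completed[:, :, :sv_points], partial)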
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
#############################################################################
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
cache=os.path.join(mesh_root, '../cache'), split='val',
categories=category,
radius=3, elev=0, azim=0, img_size=512, focal_length=1000,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=category,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def evaluate_recon_mvr(opt, netE, save_dir, logger):
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
ref = []
samples = []
images = []
masked = []
k = 0
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
B,V,N,C = x_all.shape
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
img = img_all.reshape(B * V, *img_all.shape[2:])
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
x_adj = x.reshape(B,V,N,C)* s + m
recon_adj = recon.reshape(B,V,N,C)* s + m
img = img.reshape(B,V,*img.shape[1:])
ref.append( gt_all * s + m)
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
samples.append(recon_adj)
images.append(img)
ref_pcs = torch.cat(ref, dim=0)
sample_pcs = torch.cat(samples, dim=0)
images = torch.cat(images, dim=0)
masked = torch.cat(masked, dim=0)
B, V, N, C = ref_pcs.shape
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
# torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
# Compute metrics
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
results['pc'] = sample_pcs
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
Path(os.path.join(save_dir, 'x_%03d' % i)).mkdir(exist_ok=True, parents=True)
Path(os.path.join(save_dir, 'x_ply_%03d' % i)).mkdir(exist_ok=True, parents=True)
for v in range(5):
x = x_all.transpose(1, 2).contiguous()
img = img_all
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
im = np.fliplr(np.flipud(d[-1]))
plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
# use splitext: str.strip('.pth') strips a character set, not the suffix
for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(x)[0].split('_')[-1]))):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
if opt.eval_recon_mvr:
# Evaluate generation
evaluate_recon_mvr(opt, netE, outf_syn, logger)
if opt.generate_multimodal:
generate_multimodal(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to the ShapeNet point-cloud data root')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='path to the preprocessed single-view data root')
parser.add_argument('--classes', default=['airplane'])
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=True)
parser.add_argument('--generate_multimodal', default=False)
parser.add_argument('--nc', default=3)
parser.add_argument('--npoints', default=2048)
parser.add_argument('--svpoints', default=200)
'''model'''
parser.add_argument('--beta_start', default=0.0001)
parser.add_argument('--beta_end', default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/airplane_ckpt/', help="path to netE (to continue training)")
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
set_seed(opt)
main(opt)

View file

@ -0,0 +1,764 @@
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.visualize import *
from utils.file_utils import *
from utils.mitsuba_renderer import write_to_xml_batch
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): interval (in timesteps) at which intermediate
samples are appended to the returned list.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
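# Hedged illustration (not part of the original file): at t = T-1 the forward
# process q_sample has driven the free points to nearly pure N(0, I) noise,
# which is why reverse sampling can start from Gaussian noise.
def _forward_noising_demo():
    betas = np.linspace(1e-4, 0.02, 1000)
    diff = GaussianDiffusion(betas, 'mse', 'eps', 'fixedsmall', sv_points=0)
    x0 = torch.randn(4, 3, 128)
    t = torch.full((4,), diff.num_timesteps - 1, dtype=torch.int64)
    xt = diff.q_sample(x0, t)
    return xt.std().item()  # ~1.0: the signal component is almost entirely gone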
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
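# Hedged end-to-end sketch (not part of the original file): build a completion
# model from the pieces above and sample one completion for a random partial
# cloud. The SimpleNamespace stands in for the parsed CLI options, and a CUDA
# build of the PV-conv ops is assumed.
def _example_completion_sample():
    from types import SimpleNamespace
    args = SimpleNamespace(nc=3, svpoints=200, embed_dim=64,
                           attention=True, dropout=0.1)
    betas = get_betas('linear', 0.0001, 0.02, 1000)
    net = Model(args, betas, 'mse', 'eps', 'fixedsmall').cuda()
    net.eval()
    partial = torch.randn(1, 3, args.svpoints).cuda()  # stand-in partial input
    with torch.no_grad():
        full = net.gen_samples(partial, (1, 3, 2048 - args.svpoints), 'cuda',
                               clip_denoised=False)
    return full  # (1, 3, 2048): the fixed partial points plus the completion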
#############################################################################
def get_pc_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
return tr_dataset
def get_svr_dataset(pc_dataroot, mesh_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Singleview_Points(root_pc=pc_dataroot, root_mesh=mesh_root,
cache=os.path.join(mesh_root, '../cache'), split='val',
categories=category,
radius=3, elev=-89, azim=180, img_size=512, focal_length=1000,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=category, split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
categories=category,
npoints=npoints, sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return te_dataset
def evaluate_recon_mvr(opt, netE, save_dir, logger):
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
ref = []
samples = []
images = []
masked = []
k = 0
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
# img_all = data['image']
B,V,N,C = x_all.shape
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
# img = img_all.reshape(B * V, *img_all.shape[2:])
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
x_adj = x.reshape(B,V,N,C)* s + m
recon_adj = recon.reshape(B,V,N,C)* s + m
# img = img.reshape(B,V,*img.shape[1:])
ref.append( gt_all * s + m)
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
samples.append(recon_adj)
# images.append(img)
ref_pcs = torch.cat(ref, dim=0)
sample_pcs = torch.cat(samples, dim=0)
# images = torch.cat(images, dim=0)
masked = torch.cat(masked, dim=0)
B, V, N, C = ref_pcs.shape
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
# torch.save(images.reshape(B,V, *images.shape[2:]), os.path.join(save_dir, 'recon_depth.pth'))
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
# Compute metrics
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
results['pc'] = sample_pcs
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
del ref_pcs, masked, results
def evaluate_saved(opt, netE, save_dir, logger):
ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
gt_pth = ours_base + '/recon_gt.pth'
ours_pth = ours_base + '/ours_results.pth'
gt = torch.load(gt_pth).permute(1,0,2,3)
ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
all_res = {}
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
results = compute_all_metrics(gt_, ours_, opt.batch_size)
for key, val in results.items():
if i == 0:
all_res[key] = val
else:
all_res[key] += val
pprint(results)
for key, val in all_res.items():
all_res[key] = val / gt.shape[0]
pprint({key: val.mean().item() for key, val in all_res.items()})
def generate_multimodal(opt, netE, save_dir, logger):
test_dataset = get_svr_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
# _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.classes, use_mask=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
gt_all = data['test_points']
x_all = data['sv_points']
img_all = data['image']
for v in range(6):
x = x_all.transpose(1, 2).contiguous()
img = img_all
m, s = data['mean'].float(), data['std'].float()
recon = netE.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous()
for p, d in enumerate(zip(list(gt_all), list(recon), list(x), list(img))):
im = np.fliplr(np.flipud(d[-1]))
plt.imsave(os.path.join(save_dir, 'depth_%03d_%03d.png'%(i,p)), im, cmap='gray')
write_to_xml_batch(os.path.join(save_dir, 'x_%03d_%03d'%(i,p), 'mode_%03d'%v), (torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy(), cat='chair')
export_to_pc_batch(os.path.join(save_dir, 'x_ply_%03d_%03d'%(i,p), 'mode_%03d'%v),(torch.stack(d[:-1], dim=0)* s[0:1] + m[0:1]).numpy())
def redwood_demo(opt, netE, save_dir, logger):
import open3d as o3d
pth = "/viscam/u/alexzhou907/01DATA/redwood/01605_sample_1.ply"
pth_gt = "/viscam/u/alexzhou907/01DATA/redwood/01605_pc_gt.ply"
points = np.asarray(o3d.io.read_point_cloud(pth).points)
gt_points = np.asarray(o3d.io.read_point_cloud(pth_gt).points)
np.save('gt.npy', gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)])
write_to_xml_batch(os.path.join(save_dir, 'xml_gt'),
gt_points[np.random.choice(gt_points.shape[0], size=opt.npoints, replace=False)][None], cat='table')
test_dataset = get_pc_dataset(opt.dataroot_pc, opt.dataroot_sv,
opt.npoints, opt.classes)
m, s = torch.from_numpy(test_dataset[0]['mean']).float(), torch.from_numpy(test_dataset[0]['std']).float()
x = torch.from_numpy(points[np.random.choice(points.shape[0], size=opt.npoints, replace=False)]).float()
x = (x-m)/s
x = x[None].transpose(1,2).cuda()
res = []
for k in range(20):
recon = netE.gen_samples(x[:,:,:opt.svpoints], x[:,:,opt.svpoints:].shape, 'cuda',
clip_denoised=False).detach().cpu()
recon = recon.transpose(1, 2).contiguous()
recon = recon * s+ m
res.append(recon)
res = torch.cat(res, dim=0)
write_to_xml_batch(os.path.join(save_dir, 'xml'),
(res).numpy(), cat='table')
export_to_pc_batch(os.path.join(save_dir, 'ply'),
(res).numpy())
torch.save(res, os.path.join(save_dir, 'redwood_demo.pth'))
exit()
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
outf_syn, = setup_output_subdirs(output_dir, 'syn')
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
netE = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.cuda:
netE.cuda()
def _transform_(m):
return nn.parallel.DataParallel(m)
netE = netE.cuda()
netE.multi_gpu_wrapper(_transform_)
netE.eval()
ckpts = [os.path.join(opt.ckpt_dir, f) for f in os.listdir(opt.ckpt_dir) if f.endswith('.pth')]
with torch.no_grad():
for ckpt in reversed(sorted(ckpts, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))):
opt.netE = ckpt
logger.info("Resume Path:%s" % opt.netE)
resumed_param = torch.load(opt.netE)
netE.load_state_dict(resumed_param['model_state'])
if opt.eval_recon_mvr:
# Evaluate generation
evaluate_recon_mvr(opt, netE, outf_syn, logger)
if opt.generate_multimodal:
generate_multimodal(opt, netE, outf_syn, logger)
if opt.eval_saved:
evaluate_saved(opt, netE, outf_syn, logger)
if opt.eval_redwood:
redwood_demo(opt, netE, outf_syn, logger)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to complete point clouds')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='path to single-view data')
parser.add_argument('--classes', default=['table'])
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--eval_recon_mvr', default=False)
parser.add_argument('--generate_multimodal', default=False)
parser.add_argument('--eval_saved', default=False)
parser.add_argument('--eval_redwood', default=True)
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--npoints', type=int, default=2048)
parser.add_argument('--svpoints', type=int, default=200)
'''model'''
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', type=int, default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
# constrain function
parser.add_argument('--constrain_eps', default=0.2)
parser.add_argument('--constrain_steps', type=int, default=1)
parser.add_argument('--ckpt_dir', default='/viscam/u/alexzhou907/research/diffusion/shape_completion/output/9_res32_pc_table_mse_fs_dattng_2e2-1e4_linbeta_0wd_0.1do/2020-12-16-14-09-50', help='directory of checkpoints to evaluate')
'''eval'''
parser.add_argument('--eval_path',
default='/viscam/u/alexzhou907/research/diffusion/shapenet/output/test/2020-10-10-20-11-46/syn/epoch_2799_samples.pth')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
opt = parser.parse_args()
if torch.cuda.is_available():
opt.cuda = True
else:
opt.cuda = False
return opt
if __name__ == '__main__':
opt = parse_args()
main(opt)

View file

@ -0,0 +1,841 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import argparse
from model.unet import get_model
from torch.distributions import Normal
from utils.file_utils import *
from utils.visualize import *
from model.pvcnn_completion import PVCNN2Base
import torch.distributed as dist
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import ShapeNet_Multiview_Points
'''
some utils
'''
def rotation_matrix(axis, theta):
"""
Return the rotation matrix associated with counterclockwise rotation about
the given axis by theta radians.
"""
axis = np.asarray(axis)
axis = axis / np.sqrt(np.dot(axis, axis))
a = np.cos(theta / 2.0)
b, c, d = -axis * np.sin(theta / 2.0)
aa, bb, cc, dd = a * a, b * b, c * c, d * d
bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
[2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
[2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
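# Hedged sanity check (not part of the original file): a proper rotation matrix
# is orthonormal with unit determinant, i.e. R @ R.T == I and det(R) == +1.
def _check_rotation_matrix():
    R = rotation_matrix([0, 1, 0], np.pi / 3)
    assert np.allclose(R @ R.T, np.eye(3), atol=1e-10)
    assert np.isclose(np.linalg.det(R), 1.0)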
def rotate(vertices, faces):
'''
vertices: [numpoints, 3]
'''
M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
K = rotation_matrix([0, 0, 1], np.pi).transpose()
v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]]
return v, f
def norm(v, f):
v = (v - v.min())/(v.max() - v.min()) - 0.5
return v, f
def getGradNorm(net):
pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
return pNorm, gradNorm
def weights_init(m):
"""
xavier initialization
"""
classname = m.__class__.__name__
if classname.find('Conv') != -1 and m.weight is not None:
torch.nn.init.xavier_normal_(m.weight)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_()
m.bias.data.fill_(0)
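# Hedged usage note (not part of the original file): weights_init is written for
# nn.Module.apply, which walks the module tree and calls it on every submodule.
def _init_weights(net: nn.Module) -> nn.Module:
    return net.apply(weights_init)  # xavier for Conv*, N(0, 1)/0 for BatchNorm*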
'''
models
'''
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data lies in [0, 1]
assert x.shape == means.shape == log_scales.shape
px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
centered_x = x - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
self.alphas_cumprod_prev = alphas_cumprod_prev.float()
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
"""
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, x_recon
else:
return model_mean, model_variance, model_log_variance
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
"""
assert isinstance(shape, (tuple, list))
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
assert img_t[:,:,self.sv_points:].shape == shape
return img_t
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
Args:
freq (int): interval (in timesteps) at which intermediate
samples are appended to the returned list.
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
return (kl, pred_xstart) if return_pred_xstart else kl
def p_losses(self, denoise_fn, data_start, t, noise=None):
"""
Training loss calculation
"""
B, D, N = data_start.shape
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
if self.loss_type == 'mse':
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
class PVCNN2(PVCNN2Base):
sa_blocks = [
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (256, 256, 512))),
]
fp_blocks = [
((256, 256), (256, 3, 8)),
((256, 256), (256, 3, 8)),
((256, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
def _denoise(self, data, t):
B, D,N= data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
out = self.model(data, t)
return out
def get_loss_iter(self, data, noises=None):
B, D, N = data.shape
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def train(self):
self.model.train()
def eval(self):
self.model.eval()
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
else:
raise NotImplementedError(schedule_type)
return betas
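# Hedged illustration (not part of the original file): the 'warm*' schedules
# above hold beta at b_end except for a linear ramp over the first fraction of
# the timesteps; a short schedule makes the shapes easy to inspect.
def _print_schedules():
    for name in ['linear', 'warm0.1', 'warm0.2', 'warm0.5']:
        print(name, np.round(get_betas(name, 0.0001, 0.02, 10), 4))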
def get_dataset(dataroot_pc, dataroot_sv, npoints, svpoints, category):
tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot_pc,
categories=[category], split='train',
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
tr_dataset = ShapeNet_Multiview_Points(root_pc=dataroot_pc, root_views=dataroot_sv,
cache=os.path.join(dataroot_pc, '../cache'), split='train',
categories=[category],
npoints=npoints, sv_samples=svpoints,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return tr_dataset
def get_dataloader(opt, train_dataset, test_dataset=None):
if opt.distribution_type == 'multi':
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
if test_dataset is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=opt.world_size,
rank=opt.rank
)
else:
test_sampler = None
else:
train_sampler = None
test_sampler = None
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
if test_dataset is not None:
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.bs, sampler=test_sampler,
shuffle=False, num_workers=int(opt.workers), drop_last=False)
else:
test_dataloader = None
return train_dataloader, test_dataloader, train_sampler, test_sampler
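# Hedged sanity check (not part of the original file): with DistributedSampler,
# each rank iterates a disjoint shard of roughly len(dataset) / world_size
# samples; constructing samplers with explicit rank/num_replicas lets this be
# verified without initializing a process group.
def _check_sharding(dataset, world_size=2):
    return [len(torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=r)) for r in range(world_size)]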
def train(gpu, opt, output_dir, noises_init):
set_seed(opt)
logger = setup_logging(output_dir)
if opt.distribution_type == 'multi':
should_diag = gpu==0
else:
should_diag = True
if should_diag:
outf_syn, = setup_output_subdirs(output_dir, 'syn')
if opt.distribution_type == 'multi':
if opt.dist_url == "env://" and opt.rank == -1:
opt.rank = int(os.environ["RANK"])
base_rank = opt.rank * opt.ngpus_per_node
opt.rank = base_rank + gpu
dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
world_size=opt.world_size, rank=opt.rank)
opt.bs = int(opt.bs / opt.ngpus_per_node)
opt.workers = 0
opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
opt.diagIter = int(opt.diagIter / opt.ngpus_per_node)
opt.vizIter = int(opt.vizIter / opt.ngpus_per_node)
''' data '''
train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
'''
create networks
'''
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
if opt.distribution_type == 'multi': # Multiple processes, single GPU per process
def _transform_(m):
return nn.parallel.DistributedDataParallel(
m, device_ids=[gpu], output_device=gpu)
torch.cuda.set_device(gpu)
model.cuda(gpu)
model.multi_gpu_wrapper(_transform_)
elif opt.distribution_type == 'single':
def _transform_(m):
return nn.parallel.DataParallel(m)
model = model.cuda()
model.multi_gpu_wrapper(_transform_)
elif gpu is not None:
torch.cuda.set_device(gpu)
model = model.cuda(gpu)
else:
raise ValueError('distribution_type = multi | single | None')
if should_diag:
logger.info(opt)
optimizer= optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999))
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma)
if opt.model != '':
ckpt = torch.load(opt.model)
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
start_epoch = ckpt['epoch'] + 1
else:
start_epoch = 0
def new_x_chain(x, num_chain):
return torch.randn(num_chain, *x.shape[1:], device=x.device)
for epoch in range(start_epoch, opt.niter):
if opt.distribution_type == 'multi':
train_sampler.set_epoch(epoch)
lr_scheduler.step(epoch)
for i, data in enumerate(dataloader):
randind = np.random.choice(20) #20 views
x = data['train_points'].transpose(1,2)
sv_x = data['sv_points'][:,randind].transpose(1,2)
sv_x[:,:,opt.svpoints:] = x[:,:,opt.svpoints:]
noises_batch = noises_init[data['idx']].transpose(1,2)
'''
train diffusion
'''
if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
sv_x = sv_x.cuda(gpu)
noises_batch = noises_batch.cuda(gpu)
elif opt.distribution_type == 'single':
sv_x = sv_x.cuda()
noises_batch = noises_batch.cuda()
loss = model.get_loss_iter(sv_x, noises_batch).mean()
optimizer.zero_grad()
loss.backward()
netpNorm, netgradNorm = getGradNorm(model)
if opt.grad_clip is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
optimizer.step()
if i % opt.print_freq == 0 and should_diag:
logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
.format(
epoch, opt.niter, i, len(dataloader),loss.item(),
netpNorm, netgradNorm,
))
if (epoch + 1) % opt.diagIter == 0 and should_diag:
logger.info('Diagnosis:')
x_range = [x.min().item(), x.max().item()]
kl_stats = model.all_kl(sv_x)
logger.info(' [{:>3d}/{:>3d}] '
'x_range: [{:>10.4f}, {:>10.4f}], '
'total_bpd_b: {:>10.4f}, '
'terms_bpd: {:>10.4f}, '
'prior_bpd_b: {:>10.4f} '
'mse_bt: {:>10.4f} '
.format(
epoch, opt.niter,
*x_range,
kl_stats['total_bpd_b'].item(),
kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
))
if (epoch + 1) % opt.vizIter == 0 and should_diag:
logger.info('Generation: eval')
model.eval()
m, s = train_dataset.all_points_mean.reshape(1, -1), train_dataset.all_points_std.reshape(1, -1)
with torch.no_grad():
x_gen_eval = model.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu()
gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]
logger.info(' [{:>3d}/{:>3d}] '
'eval_gen_range: [{:>10.4f}, {:>10.4f}] '
'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] '
.format(
epoch, opt.niter,
*gen_eval_range, *gen_stats,
))
export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch),
(x_gen_eval.transpose(1, 2)*s+m).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch),
(sv_x.transpose(1, 2).detach().cpu()*s+m).numpy()*3)
export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch),
(sv_x[:,:,:opt.svpoints].transpose(1, 2).detach().cpu()*s+m).numpy()*3)
model.train()
if (epoch + 1) % opt.saveIter == 0:
if should_diag:
save_dict = {
'epoch': epoch,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict()
}
torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
if opt.distribution_type == 'multi':
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
model.load_state_dict(
torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
dist.destroy_process_group()
def main():
opt = parse_args()
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
output_dir = get_output_dir(dir_id, exp_id)
copy_source(__file__, output_dir)
''' workaround '''
train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.classes)
noises_init = torch.randn(len(train_dataset), opt.npoints-opt.svpoints, opt.nc)
if opt.dist_url == "env://" and opt.world_size == -1:
opt.world_size = int(os.environ["WORLD_SIZE"])
if opt.distribution_type == 'multi':
opt.ngpus_per_node = torch.cuda.device_count()
opt.world_size = opt.ngpus_per_node * opt.world_size
mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
else:
train(opt.gpu, opt, output_dir, noises_init)
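# Hedged usage note (not part of the original file): --distribution_type picks
# the launch mode in main() above: 'single' keeps one process and wraps the
# model in nn.DataParallel, while 'multi' spawns one process per visible GPU
# via mp.spawn and trains with DistributedDataParallel (see train()).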
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot_pc', default='/viscam/u/alexzhou907/01DATA/shapenet/ShapeNetCore.v2.PC15k', help='path to complete point clouds')
parser.add_argument('--dataroot_sv', default='/viscam/u/alexzhou907/01DATA/shapenet/shapenet_mit_preprocessed',
help='path to single-view data')
parser.add_argument('--category', default='chair')
parser.add_argument('--bs', type=int, default=48, help='input batch size')
parser.add_argument('--workers', type=int, default=16, help='workers')
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--npoints', type=int, default=2048)
parser.add_argument('--svpoints', type=int, default=200)
'''model'''
parser.add_argument('--beta_start', type=float, default=0.0001)
parser.add_argument('--beta_end', type=float, default=0.02)
parser.add_argument('--schedule_type', default='linear')
parser.add_argument('--time_num', type=int, default=1000)
#params
parser.add_argument('--attention', default=True)
parser.add_argument('--dropout', default=0.1)
parser.add_argument('--embed_dim', type=int, default=64)
parser.add_argument('--loss_type', default='mse')
parser.add_argument('--model_mean_type', default='eps')
parser.add_argument('--model_var_type', default='fixedsmall')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate (default: 2e-4)')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam (default: 0.5)')
parser.add_argument('--decay', type=float, default=0, help='weight decay')
parser.add_argument('--grad_clip', type=float, default=None, help='max gradient norm for clipping (None disables)')
parser.add_argument('--lr_gamma', type=float, default=0.998, help='multiplicative lr decay per epoch')
parser.add_argument('--model', default='', help="path to model (to continue training)")
'''distributed'''
parser.add_argument('--world_size', default=1, type=int,
help='Number of distributed nodes.')
parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--rank', default=0, type=int,
help='node rank for distributed training')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use. None means using all available GPUs.')
'''eval'''
parser.add_argument('--saveIter', type=int, default=100, help='unit: epoch')
parser.add_argument('--diagIter', type=int, default=50, help='unit: epoch')
parser.add_argument('--vizIter', type=int, default=50, help='unit: epoch')
parser.add_argument('--print_freq', type=int, default=50, help='unit: iter')
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
opt = parser.parse_args()
return opt
if __name__ == '__main__':
main()

0
shapenet/__init__.py Normal file
View file

Some files were not shown because too many files have changed in this diff.