style: autoformatting

This commit is contained in:
Laurent FAINSIN 2023-04-11 11:12:58 +02:00
parent d887d74852
commit 2fbfc320f2
44 changed files with 2424 additions and 1831 deletions

View file

@@ -1,32 +1,33 @@
from glob import glob
import re
import argparse
import numpy as np
from pathlib import Path
import os
import re
from glob import glob
from pathlib import Path
import numpy as np
def raw_camparam_from_xml(path, pose="lookAt"):
import xml.etree.ElementTree as ET
tree = ET.parse(path)
elm = tree.find("./sensor/transform/" + pose)
camparam = elm.attrib
origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
height = int(
tree.find("./sensor/film/integer[@name='height']").attrib['value'])
width = int(
tree.find("./sensor/film/integer[@name='width']").attrib['value'])
origin = np.fromstring(camparam["origin"], dtype=np.float32, sep=",")
target = np.fromstring(camparam["target"], dtype=np.float32, sep=",")
up = np.fromstring(camparam["up"], dtype=np.float32, sep=",")
height = int(tree.find("./sensor/film/integer[@name='height']").attrib["value"])
width = int(tree.find("./sensor/film/integer[@name='width']").attrib["value"])
camparam = dict()
camparam['origin'] = origin
camparam['up'] = up
camparam['target'] = target
camparam['height'] = height
camparam['width'] = width
camparam["origin"] = origin
camparam["up"] = up
camparam["target"] = target
camparam["height"] = height
camparam["width"] = width
return camparam
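For context, a minimal Mitsuba scene fragment that raw_camparam_from_xml can parse; the values here are illustrative, not taken from the dataset:

import numpy as np
import xml.etree.ElementTree as ET

xml = """
<scene>
    <sensor>
        <transform>
            <lookAt origin="1,1,1" target="0,0,0" up="0,1,0"/>
        </transform>
        <film>
            <integer name="height" value="480"/>
            <integer name="width" value="480"/>
        </film>
    </sensor>
</scene>
"""
tree = ET.ElementTree(ET.fromstring(xml))
# Same XPath queries as in the function above: lookAt attributes and film resolution.
elm = tree.find("./sensor/transform/lookAt")
origin = np.fromstring(elm.attrib["origin"], dtype=np.float32, sep=",")
height = int(tree.find("./sensor/film/integer[@name='height']").attrib["value"])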
def get_cam_pos(origin, target, up):
inward = origin - target
right = np.cross(up, inward)
@@ -38,59 +39,54 @@ def get_cam_pos(origin, target, up):
ry /= np.linalg.norm(ry)
rz /= np.linalg.norm(rz)
rot = np.stack([
rx,
ry,
-rz
], axis=0)
aff = np.concatenate([
np.eye(3), -origin[:,None]
], axis=1)
rot = np.stack([rx, ry, -rz], axis=0)
aff = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
ext = np.matmul(rot, aff)
result = np.concatenate(
[ext, np.array([[0,0,0,1]])], axis=0
)
result = np.concatenate([ext, np.array([[0, 0, 0, 1]])], axis=0)
return result
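The hunk above elides the start of the rotation setup; a self-contained sketch of the full lookAt-to-extrinsic conversion, assuming up is orthogonal to the viewing direction:

import numpy as np

def lookat_extrinsic(origin, target, up):
    # Camera looks from origin toward target; "inward" is the negated view direction.
    inward = origin - target
    rx = np.cross(up, inward)          # camera right axis
    ry = np.asarray(up, dtype=float)   # camera up axis (assumed orthogonal)
    rz = inward
    rx, ry, rz = (v / np.linalg.norm(v) for v in (rx, ry, rz))
    rot = np.stack([rx, ry, -rz], axis=0)  # world-to-camera rotation, -rz as above
    aff = np.concatenate([np.eye(3), -np.asarray(origin, dtype=float)[:, None]], axis=1)
    ext = rot @ aff                        # [R | -R t]
    return np.concatenate([ext, np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)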
def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths]
depths = sorted(glob(os.path.join(datapoint_dir, "*depth.png")))
cam_ext = [
"_".join(re.sub(dataroot.strip("/"), camera_param_dir.strip("/"), f).split("_")[:-1]) + ".xml" for f in depths
]
for i, (f, pth) in enumerate(zip(cam_ext, depths)):
if not os.path.exists(f):
continue
params=raw_camparam_from_xml(f)
origin, target, up, width, height = params['origin'], params['target'], params['up'],\
params['width'], params['height']
params = raw_camparam_from_xml(f)
origin, target, up, width, height = (
params["origin"],
params["target"],
params["up"],
params["width"],
params["height"],
)
ext_matrix = get_cam_pos(origin, target, up)
#####
diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
diag = (0.036**2 + 0.024**2) ** 0.5
focal_length = 0.05
res = [480, 480]
h_relative = (res[1] / res[0])
sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
h_relative = res[1] / res[0]
sensor_width = np.sqrt(diag**2 / (1 + h_relative**2))
pix_size = sensor_width / res[0]
K = np.array([
[focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
[0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
[0, 0, 1]
])
K = np.array(
[
[focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
[0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
[0, 0, 1],
]
)
np.savez(pth.split('depth.png')[0]+ 'cam_params.npz', extr=ext_matrix, intr=K)
np.savez(pth.split("depth.png")[0] + "cam_params.npz", extr=ext_matrix, intr=K)
def main(opt):
@@ -102,21 +98,16 @@ def main(opt):
if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
leaf_subdirs.append(dirpath)
for k, dir_ in enumerate(leaf_subdirs):
print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
print("Processing dir {}/{}: {}".format(k, len(leaf_subdirs), dir_))
convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
if __name__ == '__main__':
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument('--dataroot', type=str, default='GenReData/')
args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
args.add_argument("--dataroot", type=str, default="GenReData/")
args.add_argument("--mitsuba_xml_root", type=str, default="GenReData/genre-xml_v2")
opt = args.parse_args()

View file

@@ -1,11 +1,13 @@
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import os
import json
import os
import random
import numpy as np
import torch
import trimesh
from plyfile import PlyData, PlyElement
from torch.utils.data import Dataset
def project_pc_to_image(points, resolution=64):
"""project point clouds into 2D image
@@ -26,29 +28,32 @@ def project_pc_to_image(points, resolution=64):
def write_ply(points, filename, text=False):
""" input: Nx3, write points to filename as PLY format. """
points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
with open(filename, mode='wb') as f:
"""input: Nx3, write points to filename as PLY format."""
points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
vertex = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
with open(filename, mode="wb") as f:
PlyData([el], text=text).write(f)
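A quick round-trip check for write_ply (the path is a placeholder; PlyData.read is part of plyfile, which this module already imports):

pts = np.random.rand(100, 3).astype(np.float32)   # Nx3 input
write_ply(pts, "/tmp/sample.ply", text=True)      # text=True writes ASCII PLY
v = PlyData.read("/tmp/sample.ply")["vertex"]     # read the vertex element back
recovered = np.stack([v["x"], v["y"], v["z"]], axis=1)
assert np.allclose(recovered, pts)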
def rotate_point_cloud(points, transformation_mat):
new_points = np.dot(transformation_mat, points.T).T
return new_points
def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
""" align 3depn shapes to shapenet coordinates"""
"""align 3depn shapes to shapenet coordinates"""
# angle = math.radians(angle_deg)
# rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
# rot_m = rot_m.to_matrix()
rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
[-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
rot_m = np.array(
[
[2.22044605e-16, 0.00000000e00, 1.00000000e00],
[0.00000000e00, 1.00000000e00, 0.00000000e00],
[-1.00000000e00, 0.00000000e00, 2.22044605e-16],
]
)
new_points = rotate_point_cloud(points, rot_m)
@@ -87,14 +92,13 @@ def sample_point_cloud_by_n(points, n_pts):
return points
def collect_data_id(split_dir, classname, phase):
filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
if not os.path.exists(filename):
raise ValueError("Invalid filepath: {}".format(filename))
all_ids = []
with open(filename, 'r') as fp:
with open(filename, "r") as fp:
info = json.load(fp)
for item in info:
all_ids.append(item["anno_id"])
@@ -102,7 +106,6 @@ def collect_data_id(split_dir, classname, phase):
return all_ids
class GANdatasetPartNet(Dataset):
def __init__(self, phase, data_root, category, n_pts):
super(GANdatasetPartNet, self).__init__()
@@ -114,10 +117,12 @@ class GANdatasetPartNet(Dataset):
self.data_root = data_root
shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase)
shape_names = collect_data_id(
os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
)
self.shape_names = []
for name in shape_names:
path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
path = os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name)
if os.path.exists(path):
self.shape_names.append(name)
@@ -129,12 +134,12 @@ class GANdatasetPartNet(Dataset):
@staticmethod
def load_point_cloud(path):
pc = trimesh.load(path)
pc = pc.vertices / 2.0 # scale to unit sphere
pc = pc.vertices / 2.0 # scale to unit sphere
return pc
@staticmethod
def read_point_cloud_part_label(path):
with open(path, 'r') as fp:
with open(path, "r") as fp:
labels = fp.readlines()
labels = np.array([int(x) for x in labels])
return labels
@@ -156,26 +161,31 @@ class GANdatasetPartNet(Dataset):
def __getitem__(self, index):
raw_shape_name = self.shape_names[index]
raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
raw_ply_path = os.path.join(self.data_root, "partnet_data", raw_shape_name, "point_sample/ply-10000.ply")
raw_pc = self.load_point_cloud(raw_ply_path)
raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
raw_label_path = os.path.join(
self.data_root, "partnet_labels/partnet_pc_label", raw_shape_name, "label-merge-level1-10000.txt"
)
part_labels = self.read_point_cloud_part_label(raw_label_path)
raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
real_shape_name = self.shape_names[index]
real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
real_ply_path = os.path.join(self.data_root, "partnet_data", real_shape_name, "point_sample/ply-10000.ply")
real_pc = self.load_point_cloud(real_ply_path)
real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
'n_part_keep': n_part_keep, 'idx': index}
return {
"raw": raw_pc,
"real": real_pc,
"raw_id": raw_shape_name,
"real_id": real_shape_name,
"n_part_keep": n_part_keep,
"idx": index,
}
def __len__(self):
return len(self.shape_names)

View file

@@ -1,33 +1,67 @@
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils import data
import random
import numpy as np
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
'04554684': 'washer', '02992529': 'cellphone',
'02843684': 'birdhouse', '02871439': 'bookshelf',
"02691156": "airplane",
"02773838": "bag",
"02801938": "basket",
"02808440": "bathtub",
"02818832": "bed",
"02828884": "bench",
"02876657": "bottle",
"02880940": "bowl",
"02924116": "bus",
"02933112": "cabinet",
"02747177": "can",
"02942699": "camera",
"02954340": "cap",
"02958343": "car",
"03001627": "chair",
"03046257": "clock",
"03207941": "dishwasher",
"03211117": "monitor",
"04379243": "table",
"04401088": "telephone",
"02946921": "tin_can",
"04460130": "tower",
"04468005": "train",
"03085013": "keyboard",
"03261776": "earphone",
"03325088": "faucet",
"03337140": "file",
"03467517": "guitar",
"03513137": "helmet",
"03593526": "jar",
"03624134": "knife",
"03636649": "lamp",
"03642806": "laptop",
"03691459": "speaker",
"03710193": "mailbox",
"03759954": "microphone",
"03761084": "microwave",
"03790512": "motorcycle",
"03797390": "mug",
"03928116": "piano",
"03938244": "pillow",
"03948459": "pistol",
"03991062": "pot",
"04004475": "printer",
"04074963": "remote_control",
"04090263": "rifle",
"04099429": "rocket",
"04225987": "skateboard",
"04256520": "sofa",
"04330267": "stove",
"04530566": "vessel",
"04554684": "washer",
"02992529": "cellphone",
"02843684": "birdhouse",
"02871439": "bookshelf",
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
@@ -35,13 +69,23 @@ cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
class Uniform15KPC(Dataset):
def __init__(self, root_dir, subdirs, tr_sample_size=10000,
te_sample_size=10000, split='train', scale=1.,
normalize_per_shape=False, box_per_shape=False,
random_subsample=False,
normalize_std_per_axis=False,
all_points_mean=None, all_points_std=None,
input_dim=3, use_mask=False):
def __init__(
self,
root_dir,
subdirs,
tr_sample_size=10000,
te_sample_size=10000,
split="train",
scale=1.0,
normalize_per_shape=False,
box_per_shape=False,
random_subsample=False,
normalize_std_per_axis=False,
all_points_mean=None,
all_points_std=None,
input_dim=3,
use_mask=False,
):
self.root_dir = root_dir
self.split = split
self.in_tr_sample_size = tr_sample_size
@@ -67,9 +111,9 @@ class Uniform15KPC(Dataset):
all_mids = []
for x in os.listdir(sub_path):
if not x.endswith('.npy'):
if not x.endswith(".npy"):
continue
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
all_mids.append(os.path.join(self.split, x[: -len(".npy")]))
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
for mid in all_mids:
@@ -111,7 +155,9 @@ class Uniform15KPC(Dataset):
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(
axis=1
).reshape(B, 1, input_dim)
else: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
@@ -129,8 +175,7 @@ class Uniform15KPC(Dataset):
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
print("Total number of data:%d" % len(self.train_points))
print("Min number of points: (train)%d (test)%d"
% (self.tr_sample_size, self.te_sample_size))
print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
assert self.scale == 1, "Scale (!= 1) is deprecated"
def get_pc_stats(self, idx):
@@ -139,7 +184,6 @@ class Uniform15KPC(Dataset):
s = self.all_points_std[idx].reshape(1, -1)
return m, s
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
def renormalize(self, mean, std):
@@ -173,11 +217,14 @@ class Uniform15KPC(Dataset):
sid, mid = self.all_cate_mids[idx]
out = {
'idx': idx,
'train_points': tr_out,
'test_points': te_out,
'mean': m, 'std': s, 'cate_idx': cate_idx,
'sid': sid, 'mid': mid
"idx": idx,
"train_points": tr_out,
"test_points": te_out,
"mean": m,
"std": s,
"cate_idx": cate_idx,
"sid": sid,
"mid": mid,
}
if self.use_mask:
@@ -192,26 +239,35 @@ class Uniform15KPC(Dataset):
# out['train_points_masked'] = masked
# out['train_masks'] = tr_mask
tr_mask = self.mask_transform(tr_out)
out['train_masks'] = tr_mask
out["train_masks"] = tr_mask
return out
class ShapeNet15kPointClouds(Uniform15KPC):
def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
split='train', scale=1., normalize_per_shape=False,
normalize_std_per_axis=False, box_per_shape=False,
random_subsample=False,
all_points_mean=None, all_points_std=None,
use_mask=False):
def __init__(
self,
root_dir="data/ShapeNetCore.v2.PC15k",
categories=["airplane"],
tr_sample_size=10000,
te_sample_size=2048,
split="train",
scale=1.0,
normalize_per_shape=False,
normalize_std_per_axis=False,
box_per_shape=False,
random_subsample=False,
all_points_mean=None,
all_points_std=None,
use_mask=False,
):
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'test', 'val']
assert self.split in ["train", "test", "val"]
self.tr_sample_size = tr_sample_size
self.te_sample_size = te_sample_size
self.cates = categories
if 'all' in categories:
if "all" in categories:
self.synset_ids = list(cate_to_synsetid.values())
else:
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
@@ -221,19 +277,21 @@ class ShapeNet15kPointClouds(Uniform15KPC):
self.display_axis_order = [0, 2, 1]
super(ShapeNet15kPointClouds, self).__init__(
root_dir, self.synset_ids,
root_dir,
self.synset_ids,
tr_sample_size=tr_sample_size,
te_sample_size=te_sample_size,
split=split, scale=scale,
normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
split=split,
scale=scale,
normalize_per_shape=normalize_per_shape,
box_per_shape=box_per_shape,
normalize_std_per_axis=normalize_std_per_axis,
random_subsample=random_subsample,
all_points_mean=all_points_mean, all_points_std=all_points_std,
input_dim=3, use_mask=use_mask)
all_points_mean=all_points_mean,
all_points_std=all_points_std,
input_dim=3,
use_mask=use_mask,
)
####################################################################################

View file

@@ -1,34 +1,70 @@
import hashlib
import os
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import os
import numpy as np
import hashlib
import torch
import matplotlib.pyplot as plt
synset_to_label = {
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
'04554684': 'washer', '02992529': 'cellphone',
'02843684': 'birdhouse', '02871439': 'bookshelf',
"02691156": "airplane",
"02773838": "bag",
"02801938": "basket",
"02808440": "bathtub",
"02818832": "bed",
"02828884": "bench",
"02876657": "bottle",
"02880940": "bowl",
"02924116": "bus",
"02933112": "cabinet",
"02747177": "can",
"02942699": "camera",
"02954340": "cap",
"02958343": "car",
"03001627": "chair",
"03046257": "clock",
"03207941": "dishwasher",
"03211117": "monitor",
"04379243": "table",
"04401088": "telephone",
"02946921": "tin_can",
"04460130": "tower",
"04468005": "train",
"03085013": "keyboard",
"03261776": "earphone",
"03325088": "faucet",
"03337140": "file",
"03467517": "guitar",
"03513137": "helmet",
"03593526": "jar",
"03624134": "knife",
"03636649": "lamp",
"03642806": "laptop",
"03691459": "speaker",
"03710193": "mailbox",
"03759954": "microphone",
"03761084": "microwave",
"03790512": "motorcycle",
"03797390": "mug",
"03928116": "piano",
"03938244": "pillow",
"03948459": "pistol",
"03991062": "pot",
"04004475": "printer",
"04074963": "remote_control",
"04090263": "rifle",
"04099429": "rocket",
"04225987": "skateboard",
"04256520": "sofa",
"04330267": "stove",
"04530566": "vessel",
"04554684": "washer",
"02992529": "cellphone",
"02843684": "birdhouse",
"02871439": "bookshelf",
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
@@ -36,30 +72,44 @@ synset_to_label = {
# Label to Synset mapping (for ShapeNet core classes)
label_to_synset = {v: k for k, v in synset_to_label.items()}
def _convert_categories(categories):
assert categories is not None, 'List of categories cannot be empty!'
if not (c in synset_to_label.keys() + label_to_synset.keys()
for c in categories):
warnings.warn('Some or all of the categories requested are not part of \
ShapeNetCore. Data loading may fail if these categories are not avaliable.')
synsets = [label_to_synset[c] if c in label_to_synset.keys()
else c for c in categories]
assert categories is not None, "List of categories cannot be empty!"
if not (c in synset_to_label.keys() + label_to_synset.keys() for c in categories):
warnings.warn(
"Some or all of the categories requested are not part of \
ShapeNetCore. Data loading may fail if these categories are not avaliable."
)
synsets = [label_to_synset[c] if c in label_to_synset.keys() else c for c in categories]
return synsets
class ShapeNet_Multiview_Points(Dataset):
def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val',
npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
def __init__(
self,
root_pc: str,
root_views: str,
cache: str,
categories: list = ["chair"],
split: str = "val",
npoints=2048,
sv_samples=800,
all_points_mean=None,
all_points_std=None,
get_image=False,
):
self.root = Path(root_views)
self.split = split
self.get_image = get_image
params = {
'cat': categories,
'npoints': npoints,
'sv_samples': sv_samples,
"cat": categories,
"npoints": npoints,
"sv_samples": sv_samples,
}
params = tuple(sorted(pair for pair in params.items()))
self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
"_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.paths = []
@@ -74,13 +124,12 @@ class ShapeNet_Multiview_Points(Dataset):
# loops through desired classes
for i in range(len(self.synsets)):
syn = self.synsets[i]
class_target = self.root / syn
if not class_target.exists():
raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
syn, self.labels[i], str(class_target)))
raise ValueError(
"Class {0} ({1}) was not found at location {2}.".format(syn, self.labels[i], str(class_target))
)
sub_path_pc = os.path.join(root_pc, syn, split)
if not os.path.isdir(sub_path_pc):
@@ -90,30 +139,30 @@ class ShapeNet_Multiview_Points(Dataset):
self.all_mids = []
self.imgs = []
for x in os.listdir(sub_path_pc):
if not x.endswith('.npy'):
if not x.endswith(".npy"):
continue
self.all_mids.append(os.path.join(split, x[:-len('.npy')]))
self.all_mids.append(os.path.join(split, x[: -len(".npy")]))
for mid in tqdm(self.all_mids):
# obj_fname = os.path.join(sub_path, x)
obj_fname = os.path.join(root_pc, syn, mid + ".npy")
cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz'))
cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
if len(cams_pths) < 20:
continue
point_cloud = np.load(obj_fname)
sv_points_group = []
img_path_group = []
(self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True)
(self.cache_dir / (mid.split("/")[-1])).mkdir(parents=True, exist_ok=True)
success = True
for i, cp in enumerate(cams_pths):
cp = str(cp)
vp = cp.split('cam_params')[0] + 'depth.png'
depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) )
vp = cp.split("cam_params")[0] + "depth.png"
depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))
cam_params = np.load(cp)
extr = cam_params['extr']
intr = cam_params['intr']
extr = cam_params["extr"]
intr = cam_params["intr"]
self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
@@ -125,7 +174,7 @@ class ShapeNet_Multiview_Points(Dataset):
sv_points_group.append(sv_point_cloud)
except Exception as e:
print(e)
success=False
success = False
break
if not success:
continue
@@ -144,64 +193,66 @@ class ShapeNet_Multiview_Points(Dataset):
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
self.train_points = self.all_points[:,:10000]
self.test_points = self.all_points[:,10000:]
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std
def get_pc_stats(self, idx):
return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1)
return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)
def __len__(self):
"""Returns the length of the dataset. """
"""Returns the length of the dataset."""
return len(self.all_points)
def __getitem__(self, index):
tr_out = self.train_points[index]
tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
tr_out = tr_out[tr_idxs, :]
gt_points = self.test_points[index][:self.npoints]
gt_points = self.test_points[index][: self.npoints]
m, s = self.get_pc_stats(index)
sv_points = self.all_points_sv[index]
idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False)
idxs = np.arange(0, sv_points.shape[-2])[
: self.sv_samples
] # np.random.choice(sv_points.shape[0], 500, replace=False)
data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(),
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
data = torch.cat(
[
torch.from_numpy(sv_points[:, idxs]).float(),
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
],
dim=1,
)
masks = torch.zeros_like(data)
masks[:,:idxs.shape[0]] = 1
masks[:, : idxs.shape[0]] = 1
res = {'train_points': torch.from_numpy(tr_out).float(),
'test_points': torch.from_numpy(gt_points).float(),
'sv_points': data,
'masks': masks,
'std': s, 'mean': m,
'idx': index,
'name':self.all_mids[index]
}
if self.split != 'train' and self.get_image:
res = {
"train_points": torch.from_numpy(tr_out).float(),
"test_points": torch.from_numpy(gt_points).float(),
"sv_points": data,
"masks": masks,
"std": s,
"mean": m,
"idx": index,
"name": self.all_mids[index],
}
if self.split != "train" and self.get_image:
img_lst = []
for n in range(self.all_points_sv.shape[1]):
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3]
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
img_lst.append(img)
img = torch.stack(img_lst, dim=0)
res['image'] = img
res["image"] = img
return res
def _render(self, cache_path, depth_pth, depth_minmax_pth):
# if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
#
@@ -210,11 +261,9 @@ class ShapeNet_Multiview_Points(Dataset):
if os.path.exists(cache_path):
data = np.load(cache_path)
else:
data, depth = self.transform(depth_pth, depth_minmax_pth)
assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
data = data[np.random.choice(data.shape[0], 600, replace=False)]
np.save(cache_path, data)
return data

View file

@@ -1,25 +1,32 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
import torch
from torch import nn
from torch.autograd import Function
chamfer_found = importlib.find_loader("chamfer_2D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 2D")
from torch.utils.cpp_extension import load
chamfer_2D = load(name="chamfer_2D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
])
chamfer_2D = load(
name="chamfer_2D",
sources=[
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer2D.cu"]),
],
)
print("Loaded JIT 2D CUDA chamfer distance")
else:
import chamfer_2D
print("Loaded compiled 2D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_2DFunction(Function):
@@ -57,9 +64,7 @@ class chamfer_2DFunction(Function):
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_2D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
chamfer_2D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
return gradxyz1, gradxyz2

View file

@@ -2,15 +2,16 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_2D',
name="chamfer_2D",
ext_modules=[
CUDAExtension('chamfer_2D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
]),
CUDAExtension(
"chamfer_2D",
[
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
"/".join(__file__.split("/")[:-1] + ["chamfer2D.cu"]),
],
),
],
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
cmdclass={
'build_ext': BuildExtension
})
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
cmdclass={"build_ext": BuildExtension},
)

View file

@@ -1,25 +1,30 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
import torch
from torch import nn
from torch.autograd import Function
chamfer_found = importlib.find_loader("chamfer_3D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 3D")
from torch.utils.cpp_extension import load
chamfer_3D = load(name="chamfer_3D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
],
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],)
chamfer_3D = load(
name="chamfer_3D",
sources=[
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer3D.cu"]),
],
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
)
print("Loaded JIT 3D CUDA chamfer distance")
else:
import chamfer_3D
print("Loaded compiled 3D CUDA chamfer distance")
@@ -60,9 +65,7 @@ class chamfer_3DFunction(Function):
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_3D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
chamfer_3D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
return gradxyz1, gradxyz2
@@ -74,4 +77,3 @@ class chamfer_3DDist(nn.Module):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_3DFunction.apply(input1, input2)

View file

@@ -2,15 +2,16 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_3D',
name="chamfer_3D",
ext_modules=[
CUDAExtension('chamfer_3D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
]),
CUDAExtension(
"chamfer_3D",
[
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
"/".join(__file__.split("/")[:-1] + ["chamfer3D.cu"]),
],
),
],
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
cmdclass={
'build_ext': BuildExtension
})
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
cmdclass={"build_ext": BuildExtension},
)

View file

@@ -1,24 +1,29 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
import torch
from torch import nn
from torch.autograd import Function
chamfer_found = importlib.find_loader("chamfer_5D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 5D")
from torch.utils.cpp_extension import load
chamfer_5D = load(name="chamfer_5D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
])
chamfer_5D = load(
name="chamfer_5D",
sources=[
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer5D.cu"]),
],
)
print("Loaded JIT 5D CUDA chamfer distance")
else:
import chamfer_5D
print("Loaded compiled 5D CUDA chamfer distance")
@@ -59,9 +64,7 @@ class chamfer_5DFunction(Function):
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_5D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
chamfer_5D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
return gradxyz1, gradxyz2

View file

@@ -2,15 +2,16 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_5D',
name="chamfer_5D",
ext_modules=[
CUDAExtension('chamfer_5D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
]),
CUDAExtension(
"chamfer_5D",
[
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
"/".join(__file__.split("/")[:-1] + ["chamfer5D.cu"]),
],
),
],
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
cmdclass={
'build_ext': BuildExtension
})
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
cmdclass={"build_ext": BuildExtension},
)

View file

@@ -33,8 +33,7 @@ def distChamfer(a, b):
xx = torch.pow(x, 2).sum(2)
yy = torch.pow(y, 2).sum(2)
zz = torch.bmm(x, y.transpose(2, 1))
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
P = rx.transpose(2, 1) + ry - 2 * zz
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
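distChamfer builds the pairwise squared-distance matrix via the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y rather than materialising the differences; a quick sanity check of that identity against torch.cdist:

import torch

a = torch.rand(2, 64, 3)                 # (batch, num_points_x, dim)
b = torch.rand(2, 128, 3)                # (batch, num_points_y, dim)
xx = a.pow(2).sum(2)                     # ||x||^2 per point
yy = b.pow(2).sum(2)                     # ||y||^2 per point
zz = torch.bmm(a, b.transpose(2, 1))     # inner products x.y
P = xx.unsqueeze(2) + yy.unsqueeze(1) - 2 * zz
assert torch.allclose(P.clamp(min=0).sqrt(), torch.cdist(a, b), atol=1e-4)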

View file

@@ -1,5 +1,6 @@
import torch
def fscore(dist1, dist2, threshold=0.001):
"""
Calculates the F-score between two point clouds with the corresponding threshold value.
@@ -14,4 +15,3 @@ def fscore(dist1, dist2, threshold=0.001):
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
fscore[torch.isnan(fscore)] = 0
return fscore, precision_1, precision_2
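Typical usage together with the CUDA chamfer distance from this repository (mirrors the test script below; requires a GPU and the built extension):

import torch
from chamfer3D.dist_chamfer_3D import chamfer_3DDist
from fscore import fscore

cham3D = chamfer_3DDist()
points1 = torch.rand(4, 100, 3).cuda()
points2 = torch.rand(4, 200, 3).cuda()
dist1, dist2, idx1, idx2 = cham3D(points1, points2)   # per-point squared distances
fs, precision_1, precision_2 = fscore(dist1, dist2)   # threshold=0.001 by default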

View file

@@ -1,20 +1,23 @@
import torch, time
import time
import chamfer2D.dist_chamfer_2D
import chamfer3D.dist_chamfer_3D
import chamfer5D.dist_chamfer_5D
import chamfer_python
import torch
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
from torch.autograd import Variable
from fscore import fscore
from torch.autograd import Variable
def test_chamfer(distChamfer, dim):
points1 = torch.rand(4, 100, dim).cuda()
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
dist1, dist2, idx1, idx2 = distChamfer(points1, points2)
loss = torch.sum(dist1)
loss.backward()
@@ -29,9 +32,9 @@ def test_chamfer(distChamfer, dim):
xd1 = idx1 - myidx1
xd2 = idx2 - myidx2
assert (
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
), "chamfer cuda and chamfer normal are not giving the same results"
print(f"fscore :", fscore(dist1, dist2))
print("fscore :", fscore(dist1, dist2))
print("Unit test passed")
@@ -49,7 +52,6 @@ def timings(distChamfer, dim):
loss.backward()
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
print("Timings : Start Pythonic version")
start = time.time()
for i in range(num_it):
@@ -61,9 +63,8 @@ def timings(distChamfer, dim):
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
dims = [2,3,5]
for i,cham in enumerate([cham2D, cham3D, cham5D]):
dims = [2, 3, 5]
for i, cham in enumerate([cham2D, cham3D, cham5D]):
print(f"testing Chamfer {dims[i]}D")
test_chamfer(cham, dims[i])
timings(cham, dims[i])

View file

@@ -1,5 +1,5 @@
import torch
import emd_cuda
import torch
class EarthMoverDistanceFunction(torch.autograd.Function):
@@ -44,4 +44,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
cost = cost / xyz1.shape[1]
return cost

View file

@@ -9,19 +9,17 @@ Notes:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
name="emd_ext",
ext_modules=[
CUDAExtension(
name='emd_cuda',
name="emd_cuda",
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
"cuda/emd.cpp",
"cuda/emd_kernel.cu",
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
),
],
cmdclass={
'build_ext': BuildExtension
})
cmdclass={"build_ext": BuildExtension},
)

View file

@@ -1,6 +1,5 @@
import torch
import numpy as np
import time
import torch
from emd import earth_mover_distance
# gt
@@ -13,10 +12,12 @@ print(p2)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist = (
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
)
print("gt_dist: ", gt_dist)
gt_dist.backward()
print(p1.grad)
@@ -41,4 +42,3 @@ print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)

View file

@@ -1,17 +1,19 @@
import torch
import numpy as np
import warnings
import numpy as np
import torch
from numpy.linalg import norm
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
from metrics.ChamferDistancePytorch.fscore import fscore
from tqdm import tqdm
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
from metrics.ChamferDistancePytorch.fscore import fscore
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
cham3D = chamfer_3DDist()
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
x, y = a, b
@@ -22,11 +24,11 @@ def distChamfer(a, b):
diag_ind = torch.arange(0, num_points).to(a).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2, 1) + ry - 2 * zz)
P = rx.transpose(2, 1) + ry - 2 * zz
return P.min(1)[0], P.min(2)[0]
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
@@ -56,13 +58,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
cd = torch.cat(cd_lst)
emd = torch.cat(emd_lst)
fs_lst = torch.cat(fs_lst).mean()
results = {
'MMD-CD': cd,
'MMD-EMD': emd,
'fscore': fs_lst
}
results = {"MMD-CD": cd, "MMD-EMD": emd, "fscore": fs_lst}
return results
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
@@ -107,7 +106,7 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
if sqrt:
M = M.abs().sqrt()
INFINITY = float('inf')
INFINITY = float("inf")
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
count = torch.zeros(n0 + n1).to(Mxx)
@@ -116,19 +115,21 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
s = {
'tp': (pred * label).sum(),
'fp': (pred * (1 - label)).sum(),
'fn': ((1 - pred) * label).sum(),
'tn': ((1 - pred) * (1 - label)).sum(),
"tp": (pred * label).sum(),
"fp": (pred * (1 - label)).sum(),
"fn": ((1 - pred) * label).sum(),
"tn": ((1 - pred) * (1 - label)).sum(),
}
s.update({
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
'acc': torch.eq(label, pred).float().mean(),
})
s.update(
{
"precision": s["tp"] / (s["tp"] + s["fp"] + 1e-10),
"recall": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
"acc_t": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
"acc_f": s["tn"] / (s["tn"] + s["fp"] + 1e-10),
"acc": torch.eq(label, pred).float().mean(),
}
)
return s
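With k=1 this is the 1-NN two-sample test: an accuracy near 0.5 means the sample set is statistically indistinguishable from the reference set. An illustrative call on random placeholder matrices (real inputs come from _pairwise_EMD_CD_):

import torch

n = 32
M_rr = torch.rand(n, n)   # ref-ref distances (placeholder)
M_rs = torch.rand(n, n)   # ref-sample distances (placeholder)
M_ss = torch.rand(n, n)   # sample-sample distances (placeholder)
stats = knn(M_rr, M_rs, M_ss, 1, sqrt=False)
print(stats["acc"])       # close to 0.5 when the two sets match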
@@ -141,9 +142,9 @@ def lgan_mmd_cov(all_dist):
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
cov = torch.tensor(cov).to(all_dist)
return {
'lgan_mmd': mmd,
'lgan_cov': cov,
'lgan_mmd_smp': mmd_smp,
"lgan_mmd": mmd,
"lgan_cov": cov,
"lgan_mmd_smp": mmd_smp,
}
@@ -153,27 +154,19 @@ def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
res_cd = lgan_mmd_cov(M_rs_cd.t())
results.update({
"%s-CD" % k: v for k, v in res_cd.items()
})
results.update({"%s-CD" % k: v for k, v in res_cd.items()})
res_emd = lgan_mmd_cov(M_rs_emd.t())
results.update({
"%s-EMD" % k: v for k, v in res_emd.items()
})
results.update({"%s-EMD" % k: v for k, v in res_emd.items()})
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
# 1-NN results
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
results.update({
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
})
results.update({"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if "acc" in k})
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
results.update({
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
})
results.update({"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if "acc" in k})
return results
@@ -227,11 +220,11 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
bound = 0.5 + epsilon
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit cube.')
warnings.warn("Point-clouds are not in unit cube.")
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
if verbose:
warnings.warn('Point-clouds are not in unit sphere.')
warnings.warn("Point-clouds are not in unit sphere.")
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
grid_coordinates = grid_coordinates.reshape(-1, 3)
@@ -260,9 +253,9 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
def jensen_shannon_divergence(P, Q):
if np.any(P < 0) or np.any(Q < 0):
raise ValueError('Negative values.')
raise ValueError("Negative values.")
if len(P) != len(Q):
raise ValueError('Non equal size.')
raise ValueError("Non equal size.")
P_ = P / np.sum(P) # Ensure probabilities.
Q_ = Q / np.sum(Q)
@@ -275,7 +268,7 @@ def jensen_shannon_divergence(P, Q):
res2 = _jsdiv(P_, Q_)
if not np.allclose(res, res2, atol=10e-5, rtol=0):
warnings.warn('Numerical values of two JSD methods don\'t agree.')
warnings.warn("Numerical values of two JSD methods don't agree.")
return res
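For reference, the consistency check above compares two computations of the Jensen-Shannon divergence; its standard equivalent forms are

\[
\mathrm{JSD}(P \,\|\, Q) = H\!\left(\tfrac{P+Q}{2}\right) - \tfrac{1}{2}\bigl(H(P) + H(Q)\bigr) = \tfrac{1}{2}\,\mathrm{KL}\!\left(P \,\Big\|\, \tfrac{P+Q}{2}\right) + \tfrac{1}{2}\,\mathrm{KL}\!\left(Q \,\Big\|\, \tfrac{P+Q}{2}\right),
\]

where H is the Shannon entropy and KL the Kullback-Leibler divergence.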
@@ -312,11 +305,9 @@ if __name__ == "__main__":
r_dist = min_r.mean().cpu().detach().item()
print(l_dist, r_dist)
emd_batch = EMD(x.cuda(), y.cuda(), False)
print(emd_batch.shape)
print(emd_batch.mean().detach().item())
jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy())
print(jsd)

View file

@@ -1,13 +1,14 @@
import functools
import torch.nn as nn
import torch
import numpy as np
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
import torch
import torch.nn as nn
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
@@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
def create_pointnet_components(
blocks,
in_channels,
embed_dim,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0
@@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
with_se=with_se, normalize=normalize, eps=eps)
block = functools.partial(
PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
with_se=with_se,
normalize=normalize,
eps=eps,
)
if c == 0:
layers.append(block(in_channels, out_channels))
else:
layers.append(block(in_channels+embed_dim, out_channels))
layers.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels
concat_channels += out_channels
c += 1
return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
dropout=0.1, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
def create_pointnet2_sa_components(
sa_blocks,
extra_feature_channels,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3
@@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0
attention = (c + 1) % 2 == 0 and c > 0 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se and not attention, with_se_relu=True,
normalize=normalize, eps=eps)
block = functools.partial(
PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se and not attention,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
if c == 0:
sa_blocks.append(block(in_channels, out_channels))
elif k ==0:
sa_blocks.append(block(in_channels+embed_dim, out_channels))
elif k == 0:
sa_blocks.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels
k += 1
extra_feature_channels = in_channels
@@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
if num_centers is None:
block = PointNetAModule
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
include_coordinates=True))
block = functools.partial(
PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
)
sa_blocks.append(
block(
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
out_channels=out_channels,
include_coordinates=True,
)
)
c += 1
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
if len(sa_blocks) == 1:
@@ -127,10 +165,20 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False,
dropout=0.1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
def create_pointnet2_fp_modules(
fp_blocks,
in_channels,
sa_in_channels,
sv_points,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
@@ -139,7 +187,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
PointNetFPModule(
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
)
)
in_channels = out_channels[-1]
@@ -151,9 +201,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps)
block = functools.partial(
PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se and not attention,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels
@@ -168,9 +226,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
class PVCNN2Base(nn.Module):
def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1,
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
def __init__(
self,
num_classes,
sv_points,
embed_dim,
use_att,
dropout=0.1,
extra_feature_channels=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
super().__init__()
assert extra_feature_channels >= 0
self.embed_dim = embed_dim
@@ -178,9 +244,14 @@ class PVCNN2Base(nn.Module):
self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
sa_blocks=self.sa_blocks,
extra_feature_channels=extra_feature_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
)
self.sa_layers = nn.ModuleList(sa_layers)
@@ -189,16 +260,26 @@ class PVCNN2Base(nn.Module):
# only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points,
with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
fp_blocks=self.fp_blocks,
in_channels=channels_sa_features,
sa_in_channels=sa_in_channels,
sv_points=sv_points,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
)
self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes],
classifier=True, dim=2, width_multiplier=width_multiplier)
layers, _ = create_mlp_components(
in_channels=channels_fp_features,
out_channels=[128, 0.5, num_classes],
classifier=True,
dim=2,
width_multiplier=width_multiplier,
)
self.classifier = nn.Sequential(*layers)
self.embedf = nn.Sequential(
@@ -223,31 +304,30 @@ class PVCNN2Base(nn.Module):
return emb
def forward(self, inputs, t):
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
# inputs : [B, in_channels + S, N]
coords, features = inputs[:, :3, :].contiguous(), inputs
coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers):
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
if i == 0:
features, coords, temb = sa_blocks ((features, coords, temb))
features, coords, temb = sa_blocks((features, coords, temb))
else:
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None:
features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers):
for fp_idx, fp_blocks in enumerate(self.fp_layers):
jump_coords = coords_list[-1 - fp_idx]
fump_feats = in_features_list[-1-fp_idx]
fump_feats = in_features_list[-1 - fp_idx]
# if fp_idx == len(self.fp_layers) - 1:
# jump_coords = jump_coords[:,:,self.sv_points:]
# fump_feats = fump_feats[:,:,self.sv_points:]
features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features,temb],dim=1), fump_feats, temb))
features, coords, temb = fp_blocks(
(jump_coords, coords, torch.cat([features, temb], dim=1), fump_feats, temb)
)
return self.classifier(features)
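get_timestep_embedding is not shown in this diff; a sketch of the standard sinusoidal DDPM-style embedding it presumably follows (illustrative, not the repository's exact code):

import math
import torch

def sinusoidal_timestep_embedding(timesteps, embed_dim):
    # Half sine / half cosine features at geometrically spaced frequencies.
    half = embed_dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :].to(timesteps.device)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # [B, embed_dim]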

View file

@@ -1,13 +1,14 @@
import functools
import torch.nn as nn
import torch
import numpy as np
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
import torch
import torch.nn as nn
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
@@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
def create_pointnet_components(
blocks,
in_channels,
embed_dim,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0
@@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
with_se=with_se, normalize=normalize, eps=eps)
block = functools.partial(
PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
with_se=with_se,
normalize=normalize,
eps=eps,
)
if c == 0:
layers.append(block(in_channels, out_channels))
else:
layers.append(block(in_channels+embed_dim, out_channels))
layers.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels
concat_channels += out_channels
c += 1
return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
dropout=0.1, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
def create_pointnet2_sa_components(
sa_blocks,
extra_feature_channels,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3
@@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and use_att and p == 0
attention = (c + 1) % 2 == 0 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, with_se_relu=True,
normalize=normalize, eps=eps)
block = functools.partial(
PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
if c == 0:
sa_blocks.append(block(in_channels, out_channels))
elif k ==0:
sa_blocks.append(block(in_channels+embed_dim, out_channels))
elif k == 0:
sa_blocks.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels
k += 1
extra_feature_channels = in_channels
@@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
if num_centers is None:
block = PointNetAModule
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
include_coordinates=True))
block = functools.partial(
PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
)
sa_blocks.append(
block(
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
out_channels=out_channels,
include_coordinates=True,
)
)
c += 1
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
if len(sa_blocks) == 1:
@@ -127,10 +165,19 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
dropout=0.1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1):
def create_pointnet2_fp_modules(
fp_blocks,
in_channels,
sa_in_channels,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
@@ -139,7 +186,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
PointNetFPModule(
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
)
)
in_channels = out_channels[-1]
@@ -147,14 +196,21 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
attention = (c + 1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, with_se_relu=True,
normalize=normalize, eps=eps)
block = functools.partial(
PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels
@@ -168,20 +224,31 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
return fp_layers, in_channels
class PVCNN2Base(nn.Module):
def __init__(self, num_classes, embed_dim, use_att, dropout=0.1,
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
def __init__(
self,
num_classes,
embed_dim,
use_att,
dropout=0.1,
extra_feature_channels=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
super().__init__()
assert extra_feature_channels >= 0
self.embed_dim = embed_dim
self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
sa_blocks=self.sa_blocks,
extra_feature_channels=extra_feature_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
)
self.sa_layers = nn.ModuleList(sa_layers)
@@ -190,15 +257,25 @@ class PVCNN2Base(nn.Module):
# only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
fp_blocks=self.fp_blocks,
in_channels=channels_sa_features,
sa_in_channels=sa_in_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
)
self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes], # was 0.5
classifier=True, dim=2, width_multiplier=width_multiplier)
layers, _ = create_mlp_components(
in_channels=channels_fp_features,
out_channels=[128, dropout, num_classes], # was 0.5
classifier=True,
dim=2,
width_multiplier=width_multiplier,
)
self.classifier = nn.Sequential(*layers)
self.embedf = nn.Sequential(
@@ -223,25 +300,30 @@ class PVCNN2Base(nn.Module):
return emb
def forward(self, inputs, t):
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
# inputs : [B, in_channels + S, N]
coords, features = inputs[:, :3, :].contiguous(), inputs
coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers):
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
if i == 0:
features, coords, temb = sa_blocks ((features, coords, temb))
features, coords, temb = sa_blocks((features, coords, temb))
else:
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None:
features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers):
features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb))
for fp_idx, fp_blocks in enumerate(self.fp_layers):
features, coords, temb = fp_blocks(
(
coords_list[-1 - fp_idx],
coords,
torch.cat([features, temb], dim=1),
in_features_list[-1 - fp_idx],
temb,
)
)
return self.classifier(features)

View file

@@ -1,8 +0,0 @@
from modules.ball_query import BallQuery
from modules.frustum import FrustumPointNetLoss
from modules.loss import KLLoss
from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
from modules.pvconv import PVConv, Attention, Swish, PVConvReLU
from modules.se import SE3d
from modules.shared_mlp import SharedMLP
from modules.voxelization import Voxelization

View file

@@ -3,7 +3,7 @@ import torch.nn as nn
import modules.functional as F
__all__ = ['BallQuery']
__all__ = ["BallQuery"]
class BallQuery(nn.Module):
@@ -21,7 +21,7 @@ class BallQuery(nn.Module):
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
if points_features is None:
assert self.include_coordinates, 'No Features For Grouping'
assert self.include_coordinates, "No Features For Grouping"
neighbor_features = neighbor_coordinates
else:
neighbor_features = F.grouping(points_features, neighbor_indices)
@@ -30,5 +30,6 @@ class BallQuery(nn.Module):
return neighbor_features, F.grouping(temb, neighbor_indices)
def extra_repr(self):
return 'radius={}, num_neighbors={}{}'.format(
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
return "radius={}, num_neighbors={}{}".format(
self.radius, self.num_neighbors, ", include coordinates" if self.include_coordinates else ""
)

View file

@@ -5,12 +5,20 @@ import torch.nn.functional as F
import modules.functional as PF
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
__all__ = ["FrustumPointNetLoss", "get_box_corners_3d"]
class FrustumPointNetLoss(nn.Module):
def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
def __init__(
self,
num_heading_angle_bins,
num_size_templates,
size_templates,
box_loss_weight=1.0,
corners_loss_weight=10.0,
heading_residual_loss_weight=20.0,
size_residual_loss_weight=20.0,
):
super().__init__()
self.box_loss_weight = box_loss_weight
self.corners_loss_weight = corners_loss_weight
@@ -19,28 +27,28 @@ class FrustumPointNetLoss(nn.Module):
self.num_heading_angle_bins = num_heading_angle_bins
self.num_size_templates = num_size_templates
self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
self.register_buffer("size_templates", size_templates.view(self.num_size_templates, 3))
self.register_buffer(
'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
"heading_angle_bin_centers", torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
)
def forward(self, inputs, targets):
mask_logits = inputs['mask_logits'] # (B, 2, N)
center_reg = inputs['center_reg'] # (B, 3)
center = inputs['center'] # (B, 3)
heading_scores = inputs['heading_scores'] # (B, NH)
heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
heading_residuals = inputs['heading_residuals'] # (B, NH)
size_scores = inputs['size_scores'] # (B, NS)
size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
size_residuals = inputs['size_residuals'] # (B, NS, 3)
mask_logits = inputs["mask_logits"] # (B, 2, N)
center_reg = inputs["center_reg"] # (B, 3)
center = inputs["center"] # (B, 3)
heading_scores = inputs["heading_scores"] # (B, NH)
heading_residuals_normalized = inputs["heading_residuals_normalized"] # (B, NH)
heading_residuals = inputs["heading_residuals"] # (B, NH)
size_scores = inputs["size_scores"] # (B, NS)
size_residuals_normalized = inputs["size_residuals_normalized"] # (B, NS, 3)
size_residuals = inputs["size_residuals"] # (B, NS, 3)
mask_logits_target = targets['mask_logits'] # (B, N)
center_target = targets['center'] # (B, 3)
heading_bin_id_target = targets['heading_bin_id'] # (B, )
heading_residual_target = targets['heading_residual'] # (B, )
size_template_id_target = targets['size_template_id'] # (B, )
size_residual_target = targets['size_residual'] # (B, 3)
mask_logits_target = targets["mask_logits"] # (B, N)
center_target = targets["center"] # (B, 3)
heading_bin_id_target = targets["heading_bin_id"] # (B, )
heading_residual_target = targets["heading_residual"] # (B, )
size_template_id_target = targets["size_template_id"] # (B, )
size_residual_target = targets["size_residual"] # (B, 3)
batch_size = center.size(0)
batch_id = torch.arange(batch_size, device=center.device)
@@ -65,25 +73,32 @@ class FrustumPointNetLoss(nn.Module):
)
# Bounding box losses
heading = (heading_residuals[batch_id, heading_bin_id_target]
+ self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
heading = (
heading_residuals[batch_id, heading_bin_id_target] + self.heading_angle_bin_centers[heading_bin_id_target]
) # (B, )
# Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets)
size = (size_residuals[batch_id, size_template_id_target]
+ self.size_templates[size_template_id_target]) # (B, 3)
size = (
size_residuals[batch_id, size_template_id_target] + self.size_templates[size_template_id_target]
) # (B, 3)
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
sizes=size_target, with_flip=True) # (B, 3, 8)
corners_loss = PF.huber_loss(torch.min(
torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
), delta=1.0)
corners_target, corners_target_flip = get_box_corners_3d(
centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
) # (B, 3, 8)
corners_loss = PF.huber_loss(
torch.min(torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)),
delta=1.0,
)
# Summing up
loss = mask_loss + self.box_loss_weight * (
center_loss + center_reg_loss + heading_loss + size_loss
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
+ self.size_residual_loss_weight * size_residual_normalized_loss
+ self.corners_loss_weight * corners_loss
center_loss
+ center_reg_loss
+ heading_loss
+ size_loss
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
+ self.size_residual_loss_weight * size_residual_normalized_loss
+ self.corners_loss_weight * corners_loss
)
return loss
@@ -105,9 +120,9 @@ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
l = sizes[:, 0] # (N,)
w = sizes[:, 1] # (N,)
h = sizes[:, 2] # (N,)
x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8)
y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8)
z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8)
x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
c = torch.cos(headings) # (N,)
s = torch.sin(headings) # (N,)

View file

@@ -1,7 +0,0 @@
from modules.functional.ball_query import ball_query
from modules.functional.devoxelization import trilinear_devoxelize
from modules.functional.grouping import grouping
from modules.functional.interpolatation import nearest_neighbor_interpolate
from modules.functional.loss import kl_loss, huber_loss
from modules.functional.sampling import gather, furthest_point_sample, logits_mask
from modules.functional.voxelization import avg_voxelize

View file

@@ -3,24 +3,28 @@ import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
_backend = load(name='_pvcnn_backend',
extra_cflags=['-O3', '-std=c++17'],
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
sources=[os.path.join(_src_path,'src', f) for f in [
'ball_query/ball_query.cpp',
'ball_query/ball_query.cu',
'grouping/grouping.cpp',
'grouping/grouping.cu',
'interpolate/neighbor_interpolate.cpp',
'interpolate/neighbor_interpolate.cu',
'interpolate/trilinear_devox.cpp',
'interpolate/trilinear_devox.cu',
'sampling/sampling.cpp',
'sampling/sampling.cu',
'voxelization/vox.cpp',
'voxelization/vox.cu',
'bindings.cpp',
]]
)
_backend = load(
name="_pvcnn_backend",
extra_cflags=["-O3", "-std=c++17"],
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
sources=[
os.path.join(_src_path, "src", f)
for f in [
"ball_query/ball_query.cpp",
"ball_query/ball_query.cu",
"grouping/grouping.cpp",
"grouping/grouping.cu",
"interpolate/neighbor_interpolate.cpp",
"interpolate/neighbor_interpolate.cu",
"interpolate/trilinear_devox.cpp",
"interpolate/trilinear_devox.cu",
"sampling/sampling.cpp",
"sampling/sampling.cu",
"voxelization/vox.cpp",
"voxelization/vox.cu",
"bindings.cpp",
]
],
)
__all__ = ['_backend']
__all__ = ["_backend"]

View file

@@ -1,19 +1,17 @@
from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['ball_query']
__all__ = ["ball_query"]
def ball_query(centers_coords, points_coords, radius, num_neighbors):
"""
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param radius: float, radius of ball query
:param num_neighbors: int, maximum number of neighbors
:return:
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
"""
centers_coords = centers_coords.contiguous()
points_coords = points_coords.contiguous()
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
"""
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param radius: float, radius of ball query
:param num_neighbors: int, maximum number of neighbors
:return:
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
"""
centers_coords = centers_coords.contiguous()
points_coords = points_coords.contiguous()
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
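A minimal usage sketch for ball_query, with hypothetical shapes and a CUDA device assumed (the _backend is a compiled CUDA extension, so CPU tensors will not work):

import torch
centers_coords = torch.rand(2, 3, 16, device="cuda")   # [B, 3, M] query centers
points_coords = torch.rand(2, 3, 1024, device="cuda")  # [B, 3, N] point cloud
neighbor_indices = ball_query(centers_coords, points_coords, radius=0.1, num_neighbors=32)
# neighbor_indices: IntTensor [B, M, 32], indices of points inside each ball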

View file

@@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['trilinear_devoxelize']
__all__ = ["trilinear_devoxelize"]
class TrilinearDevoxelization(Function):
@@ -29,7 +29,7 @@ class TrilinearDevoxelization(Function):
@staticmethod
def backward(ctx, grad_output):
"""
:param ctx:
:param ctx:
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
:return:
gradient of inputs, FloatTensor[B, C, R, R, R]

View file

@@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['grouping']
__all__ = ["grouping"]
class Grouping(Function):
@@ -23,7 +23,7 @@ class Grouping(Function):
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None

View file

@@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['nearest_neighbor_interpolate']
__all__ = ["nearest_neighbor_interpolate"]
class NeighborInterpolation(Function):

View file

@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
__all__ = ['kl_loss', 'huber_loss']
__all__ = ["kl_loss", "huber_loss"]
def kl_loss(x, y):
@@ -13,5 +13,5 @@ def kl_loss(x, y):
def huber_loss(error, delta):
abs_error = torch.abs(error)
quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
return torch.mean(losses)
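# Worked check with delta = 1.0: an error of 0.5 stays in the quadratic
# branch (0.5 * 0.5**2 = 0.125), while an error of 3.0 hits the linear tail
# (0.5 * 1.0**2 + 1.0 * (3.0 - 1.0) = 2.5), matching the standard Huber loss.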

View file

@@ -4,7 +4,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
__all__ = ["gather", "furthest_point_sample", "logits_mask"]
class Gather(Function):
@@ -26,7 +26,7 @@ class Gather(Function):
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None
@@ -60,11 +60,12 @@ def logits_mask(coords, logits, num_points_per_object):
mask: mask to select points, BoolTensor[B, N]
"""
batch_size, _, num_points = coords.shape
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
torch.ones_like(num_candidates)).float() # [B, C]
masked_coords_mean = (
torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, torch.ones_like(num_candidates)).float()
) # [B, C]
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
for i in range(batch_size):
current_mask = mask[i] # [N]
@@ -74,10 +75,14 @@ def logits_mask(coords, logits, num_points_per_object):
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
selected_indices[i] = current_candidates[choices]
elif current_num_candidates > 0:
choices = np.concatenate([
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
])
choices = np.concatenate(
[
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
np.random.choice(
current_num_candidates, num_points_per_object % current_num_candidates, replace=False
),
]
)
np.random.shuffle(choices)
selected_indices[i] = current_candidates[choices]
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
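# Sampling strategy in the loop above: points whose foreground logit beats
# the background logit are candidates; with enough candidates, sample
# num_points_per_object of them without replacement, otherwise tile the
# candidate indices and top up with a random remainder before gathering the
# mean-centered coordinates of the chosen points.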

View file

@@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend
__all__ = ['avg_voxelize']
__all__ = ["avg_voxelize"]
class AvgVoxelization(Function):

View file

@@ -2,7 +2,7 @@ import torch.nn as nn
import modules.functional as F
__all__ = ['KLLoss']
__all__ = ["KLLoss"]
class KLLoss(nn.Module):

View file

@@ -5,7 +5,7 @@ import modules.functional as F
from modules.ball_query import BallQuery
from modules.shared_mlp import SharedMLP
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
__all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]
class PointNetAModule(nn.Module):
@@ -20,8 +20,9 @@ class PointNetAModule(nn.Module):
total_out_channels = 0
for _out_channels in out_channels:
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=1)
SharedMLP(
in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=1
)
)
total_out_channels += _out_channels[-1]
@@ -43,7 +44,7 @@ class PointNetAModule(nn.Module):
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
def extra_repr(self):
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"
class PointNetSAModule(nn.Module):
@@ -67,8 +68,9 @@ class PointNetSAModule(nn.Module):
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
)
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=2)
SharedMLP(
in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=2
)
)
total_out_channels += _out_channels[-1]
@@ -90,7 +92,7 @@ class PointNetSAModule(nn.Module):
return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
def extra_repr(self):
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
return f"num_centers={self.num_centers}, out_channels={self.out_channels}"
class PointNetFPModule(nn.Module):
@@ -107,7 +109,5 @@ class PointNetFPModule(nn.Module):
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
if points_features is not None:
interpolated_features = torch.cat(
[interpolated_features, points_features], dim=1
)
interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
return self.mlp(interpolated_features), points_coords, interpolated_temb
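# i.e. a feature-propagation step: center features (and temb) are upsampled
# back to the dense points by nearest-neighbor interpolation, concatenated
# with the skip features when present, and mixed by the shared MLP.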

View file

@@ -1,16 +1,17 @@
import torch.nn as nn
import torch
import modules.functional as F
from modules.voxelization import Voxelization
from modules.shared_mlp import SharedMLP
from modules.se import SE3d
import torch.nn as nn
__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
import modules.functional as F
from modules.se import SE3d
from modules.shared_mlp import SharedMLP
from modules.voxelization import Voxelization
__all__ = ["PVConv", "Attention", "Swish", "PVConvReLU"]
class Swish(nn.Module):
def forward(self,x):
return x * torch.sigmoid(x)
def forward(self, x):
return x * torch.sigmoid(x)
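# Aside: x * torch.sigmoid(x) is exactly the SiLU activation, so nn.SiLU()
# should be a drop-in equivalent on PyTorch >= 1.7.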
class Attention(nn.Module):
@@ -35,23 +36,19 @@ class Attention(nn.Module):
self.sm = nn.Softmax(-1)
def forward(self, x):
B, C = x.shape[:2]
h = x
q = self.q(h).reshape(B, C, -1)
k = self.k(h).reshape(B, C, -1)
v = self.v(h).reshape(B, C, -1)
q = self.q(h).reshape(B,C,-1)
k = self.k(h).reshape(B,C,-1)
v = self.v(h).reshape(B,C,-1)
qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
qk = torch.matmul(q.permute(0, 2, 1), k) # * (int(C) ** (-0.5))
w = self.sm(qk)
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:])
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, *x.shape[2:])
h = self.out(h)
@@ -61,9 +58,21 @@ class Attention(nn.Module):
return x
class PVConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
resolution,
attention=False,
dropout=0.1,
with_se=False,
with_se_relu=False,
normalize=True,
eps=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -74,13 +83,13 @@ class PVConv(nn.Module):
voxel_layers = [
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.GroupNorm(num_groups=8, num_channels=out_channels),
Swish()
Swish(),
]
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
voxel_layers += [
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.GroupNorm(num_groups=8, num_channels=out_channels),
Attention(out_channels, 8) if attention else Swish()
Attention(out_channels, 8) if attention else Swish(),
]
if with_se:
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
@@ -96,10 +105,21 @@ class PVConv(nn.Module):
return fused_features, coords, temb
class PVConvReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
resolution,
attention=False,
leak=0.2,
dropout=0.1,
with_se=False,
with_se_relu=False,
normalize=True,
eps=0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -110,13 +130,13 @@ class PVConvReLU(nn.Module):
voxel_layers = [
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.BatchNorm3d(out_channels),
nn.LeakyReLU(leak, True)
nn.LeakyReLU(leak, True),
]
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
voxel_layers += [
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.BatchNorm3d(out_channels),
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True),
]
if with_se:
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))

View file

@@ -1,18 +1,22 @@
import torch.nn as nn
import torch
__all__ = ['SE3d']
import torch.nn as nn
__all__ = ["SE3d"]
class Swish(nn.Module):
def forward(self,x):
return x * torch.sigmoid(x)
def forward(self, x):
return x * torch.sigmoid(x)
class SE3d(nn.Module):
def __init__(self, channel, reduction=8, use_relu=False):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(True) if use_relu else Swish() ,
nn.ReLU(True) if use_relu else Swish(),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
nn.Sigmoid(),
)
def forward(self, inputs):

View file

@@ -1,12 +1,13 @@
import torch.nn as nn
import torch
import torch.nn as nn
__all__ = ['SharedMLP']
__all__ = ["SharedMLP"]
class Swish(nn.Module):
def forward(self,x):
return x * torch.sigmoid(x)
def forward(self, x):
return x * torch.sigmoid(x)
class SharedMLP(nn.Module):
def __init__(self, in_channels, out_channels, dim=1):
@@ -23,11 +24,13 @@ class SharedMLP(nn.Module):
out_channels = [out_channels]
layers = []
for oc in out_channels:
layers.extend([
conv(in_channels, oc, 1),
bn(8, oc),
Swish(),
])
layers.extend(
[
conv(in_channels, oc, 1),
bn(8, oc),
Swish(),
]
)
in_channels = oc
self.layers = nn.Sequential(*layers)

View file

@@ -3,7 +3,7 @@ import torch.nn as nn
import modules.functional as F
__all__ = ['Voxelization']
__all__ = ["Voxelization"]
class Voxelization(nn.Module):
@@ -17,7 +17,10 @@ class Voxelization(nn.Module):
coords = coords.detach()
norm_coords = coords - coords.mean(2, keepdim=True)
if self.normalize:
norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
norm_coords = (
norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps)
+ 0.5
)
else:
norm_coords = (norm_coords + 1) / 2.0
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
@@ -25,4 +28,4 @@ class Voxelization(nn.Module):
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
def extra_repr(self):
return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
return "resolution={}{}".format(self.r, ", normalized eps = {}".format(self.eps) if self.normalize else "")
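# In short: coordinates are centered, squeezed into [0, 1] (by the maximum
# point norm plus eps when normalize=True, affinely otherwise), scaled and
# clamped to [0, r - 1], and presumably quantized to integer voxel indices
# before avg_voxelize pools features per voxel.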

View file

@@ -1,26 +1,27 @@
import argparse
from pprint import pprint
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
import torch.nn as nn
import torch.utils.data
import argparse
from torch.distributions import Normal
from utils.file_utils import *
from model.pvcnn_completion import PVCNN2Base
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
'''
from metrics.evaluation_metrics import EMD_CD, compute_all_metrics
from model.pvcnn_completion import PVCNN2Base
from utils.file_utils import *
"""
models
'''
"""
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
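# Sanity-check sketch (hypothetical tensors): this closed form should agree
# with torch.distributions, since for Gaussians
#   KL = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
#        + (mean1 - mean2)**2 * exp(-logvar2) - 1).
# For example:
#   from torch.distributions import Normal, kl_divergence
#   m1, lv1 = torch.zeros(4), torch.zeros(4)
#   m2, lv2 = torch.ones(4), torch.full((4,), -1.0)
#   ref = kl_divergence(Normal(m1, (0.5 * lv1).exp()), Normal(m2, (0.5 * lv2).exp()))
#   assert torch.allclose(normal_kl(m1, lv1, m2, lv2), ref, atol=1e-6)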
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
@@ -31,21 +32,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
min_in = inv_stdv * (centered_x - .5)
min_in = inv_stdv * (centered_x - 0.5)
cdf_min = px0.cdf(min_in)
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = torch.where(
x < 0.001, log_cdf_plus,
torch.where(x > 0.999, log_one_minus_cdf_min,
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
x < 0.001,
log_cdf_plus,
torch.where(
x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
),
)
assert log_probs.shape == x.shape
return log_probs
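# i.e. for x near 0 use log CDF(x + 0.5); for x near 1 use
# log(1 - CDF(x - 0.5)); otherwise take the log of the CDF mass between
# x - 0.5 and x + 0.5, with each term clamped at 1e-12 before the log.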
class GaussianDiffusion:
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
self.loss_type = loss_type
@@ -54,15 +57,15 @@ class GaussianDiffusion:
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.sv_points = sv_points
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
alphas = 1. - betas
alphas = 1.0 - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
@@ -70,21 +73,23 @@ class GaussianDiffusion:
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
self.posterior_log_variance_clipped = torch.log(
torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
)
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
@@ -92,17 +97,15 @@ class GaussianDiffusion:
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
bs, = t.shape
(bs,) = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
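# e.g. with a = sqrt_alphas_cumprod of shape [T], t of shape [B], and
# x_shape = (B, C, N), this returns a[t] reshaped to [B, 1, 1] so it
# broadcasts cleanly against per-point tensors.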
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
@@ -114,54 +117,59 @@ class GaussianDiffusion:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
+ self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
+ self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
posterior_log_variance_clipped = self._extract(
self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
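# Reference identities (standard DDPM algebra, matching the coefficients
# precomputed in __init__):
#   posterior_mean     = coef1 * x_0 + coef2 * x_t
#   coef1              = beta_t * sqrt(alphabar_{t-1}) / (1 - alphabar_t)
#   coef2              = sqrt(alpha_t) * (1 - alphabar_{t-1}) / (1 - alphabar_t)
#   posterior_variance = beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)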
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)[:, :, self.sv_points :]
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
if self.model_var_type in ["fixedsmall", "fixedlarge"]:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
"fixedlarge": (
self.betas.to(data.device),
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
),
"fixedsmall": (
self.posterior_variance.to(data.device),
self.posterior_log_variance_clipped.to(data.device),
),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
else:
raise NotImplementedError(self.model_var_type)
if self.model_mean_type == 'eps':
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
if self.model_mean_type == "eps":
x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t)
else:
raise NotImplementedError(self.loss_type)
assert model_mean.shape == x_recon.shape
assert model_variance.shape == model_log_variance.shape
if return_pred_xstart:
@@ -172,30 +180,31 @@ class GaussianDiffusion:
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
''' samples '''
""" samples """
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
"""
Sample from the model
"""
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
return_pred_xstart=True)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
)
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
# no noise when t == 0
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1)
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
def p_sample_loop(
self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False
):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
@@ -206,14 +215,21 @@ class GaussianDiffusion:
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised, return_pred_xstart=False)
img_t = self.p_sample(
denoise_fn=denoise_fn,
data=img_t,
t=t_,
noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False,
)
assert img_t[:,:,self.sv_points:].shape == shape
assert img_t[:, :, self.sv_points :].shape == shape
return img_t
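# Usage sketch (hypothetical shapes): completing a partial cloud of S fixed
# points up to N total points,
#   partial = x[:, :, :S]                                   # [B, C, S]
#   full = diffusion.p_sample_loop(partial, denoise_fn,
#                                  shape=(B, C, N - S), device="cuda")
# `shape` covers the free part only; the returned tensor is [B, C, N], with
# the first S columns equal to the conditioning points.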
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
def p_sample_loop_trajectory(
self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False
):
"""
Generate samples, returning intermediate images
Useful for visualizing how denoised images evolve over time
@@ -223,31 +239,38 @@ class GaussianDiffusion:
"""
assert isinstance(shape, (tuple, list))
total_steps = self.num_timesteps if not keep_running else len(self.betas)
total_steps = self.num_timesteps if not keep_running else len(self.betas)
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
imgs = [img_t]
for t in reversed(range(0,total_steps)):
for t in reversed(range(0, total_steps)):
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False)
if t % freq == 0 or t == total_steps-1:
img_t = self.p_sample(
denoise_fn=denoise_fn,
data=img_t,
t=t_,
noise_fn=noise_fn,
clip_denoised=clip_denoised,
return_pred_xstart=False,
)
if t % freq == 0 or t == total_steps - 1:
imgs.append(img_t)
assert imgs[-1].shape == shape
return imgs
'''losses'''
"""losses"""
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t
)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0)
return (kl, pred_xstart) if return_pred_xstart else kl
@@ -259,66 +282,87 @@ class GaussianDiffusion:
assert t.shape == torch.Size([B])
if noise is None:
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
noise = torch.randn(
data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device
)
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise)
if self.loss_type == 'mse':
if self.loss_type == "mse":
# predict the noise instead of x_start. seems to be weighted naturally like SNR
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[
:, :, self.sv_points :
]
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == 'kl':
losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
elif self.loss_type == "kl":
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
return_pred_xstart=False)
denoise_fn=denoise_fn,
data_start=data_start,
data_t=data_t,
t=t,
clip_denoised=False,
return_pred_xstart=False,
)
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == torch.Size([B])
return losses
'''debug'''
"""debug"""
def _prior_bpd(self, x_start):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
kl_prior = normal_kl(
mean1=qt_mean,
logvar1=qt_log_variance,
mean2=torch.tensor([0.0]).to(qt_mean),
logvar2=torch.tensor([0.0]).to(qt_log_variance),
)
assert kl_prior.shape == x_start.shape
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0)
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
with torch.no_grad():
B, T = x_start.shape[0], self.num_timesteps
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
for t in reversed(range(T)):
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
# Calculate VLB term at the current timestep
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
data_t = torch.cat(
[x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)],
dim=-1,
)
new_vals_b, pred_xstart = self._vb_terms_bpd(
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
clip_denoised=clip_denoised, return_pred_xstart=True)
denoise_fn,
data_start=x_start,
data_t=data_t,
t=t_b,
clip_denoised=clip_denoised,
return_pred_xstart=True,
)
# MSE for progressive prediction loss
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean(
dim=list(range(1, len(pred_xstart.shape)))
)
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
# Insert the calculated term into the tensor of all terms
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
assert mask_bt.shape == vals_bt_.shape == mse_bt_.shape == torch.Size([B, T])
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :])
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
assert vals_bt_.shape == mse_bt_.shape == torch.Size(
[B, T]
) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
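# Returned values (batch means): total bits-per-dim, the mean of the [B, T]
# per-timestep VLB table, the prior KL in bits, and the mean per-timestep MSE
# of the predicted x_0, all computed on the completed (non-conditioning) points.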
@@ -336,39 +380,53 @@ class PVCNN2(PVCNN2Base):
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
voxel_resolution_multiplier=1):
def __init__(
self,
num_classes,
sv_points,
embed_dim,
use_att,
dropout,
extra_feature_channels=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
super().__init__(
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
num_classes=num_classes,
sv_points=sv_points,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
)
class Model(nn.Module):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
dropout=args.dropout, extra_feature_channels=0)
self.model = PVCNN2(
num_classes=args.nc,
sv_points=args.svpoints,
embed_dim=args.embed_dim,
use_att=args.attention,
dropout=args.dropout,
extra_feature_channels=0,
)
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {
'total_bpd_b': total_bpd_b,
'terms_bpd': vals_bt,
'prior_bpd_b': prior_bpd_b,
'mse_bt':mse_bt
}
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
def _denoise(self, data, t):
B, D,N= data.shape
B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
@@ -381,20 +439,22 @@ class Model(nn.Module):
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
losses = self.diffusion.p_losses(
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
clip_denoised=True,
keep_running=False):
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running)
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
return self.diffusion.p_sample_loop(
partial_x,
self._denoise,
shape=shape,
device=device,
noise_fn=noise_fn,
clip_denoised=clip_denoised,
keep_running=keep_running,
)
def train(self):
self.model.train()
@@ -405,21 +465,19 @@ class Model(nn.Module):
def multi_gpu_wrapper(self, f):
self.model = f(self.model)
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == 'linear':
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == 'warm0.1':
def get_betas(schedule_type, b_start, b_end, time_num):
if schedule_type == "linear":
betas = np.linspace(b_start, b_end, time_num)
elif schedule_type == "warm0.1":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.2':
elif schedule_type == "warm0.2":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
elif schedule_type == 'warm0.5':
elif schedule_type == "warm0.5":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
@@ -428,22 +486,29 @@ def get_betas(schedule_type, b_start, b_end, time_num):
return betas
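# Example (hypothetical hyperparameters): a standard linear schedule is
#   betas = get_betas("linear", b_start=1e-4, b_end=2e-2, time_num=1000)
# while "warm0.1" holds beta at b_end except for a linear ramp from b_start
# over the first 10% of steps.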
#############################################################################
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
categories=[category], split='train',
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
tr_dataset = ShapeNet15kPointClouds(
root_dir=pc_dataroot,
categories=[category],
split="train",
tr_sample_size=npoints,
te_sample_size=npoints,
scale=1.,
scale=1.0,
normalize_per_shape=False,
normalize_std_per_axis=False,
random_subsample=True)
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
cache=os.path.join(pc_dataroot, '../cache'), split='val',
random_subsample=True,
)
te_dataset = ShapeNet_Multiview_Points(
root_pc=pc_dataroot,
root_views=views_root,
cache=os.path.join(pc_dataroot, "../cache"),
split="val",
categories=[category],
npoints=npoints, sv_samples=200,
npoints=npoints,
sv_samples=200,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
@@ -451,39 +516,41 @@ def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
def evaluate_recon_mvr(opt, model, save_dir, logger):
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.category)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
    )

    ref = []
    samples = []
    masked = []
    k = 0

    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Reconstructing Samples"):
        gt_all = data["test_points"]
        x_all = data["sv_points"]

        B, V, N, C = x_all.shape
        gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)

        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        m, s = data["mean"].float(), data["std"].float()

        recon = (
            model.gen_samples(
                x[:, :, : opt.svpoints].cuda(), x[:, :, opt.svpoints :].shape, "cuda", clip_denoised=False
            )
            .detach()
            .cpu()
        )
        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        x_adj = x.reshape(B, V, N, C) * s + m
        recon_adj = recon.reshape(B, V, N, C) * s + m

        ref.append(gt_all * s + m)
        masked.append(x_adj[:, :, : test_dataloader.dataset.sv_samples, :])
samples.append(recon_adj)
ref_pcs = torch.cat(ref, dim=0)
@ -492,31 +559,40 @@ def evaluate_recon_mvr(opt, model, save_dir, logger):
B, V, N, C = ref_pcs.shape
    torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, "recon_gt.pth"))
    torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, "recon_masked.pth"))
# Compute metrics
    results = EMD_CD(sample_pcs.reshape(B * V, N, C), ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
    results = {
        ky: val.reshape(B, V) if val.shape == torch.Size([B * V]) else val
        for ky, val in results.items()
    }
pprint({key: val.mean().item() for key, val in results.items()})
logger.info({key: val.mean().item() for key, val in results.items()})
    results["pc"] = sample_pcs
    torch.save(results, os.path.join(save_dir, "ours_results.pth"))
del ref_pcs, masked, results
def evaluate_saved(opt, saved_dir):
# ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
    gt_pth = saved_dir + "/recon_gt.pth"
    ours_pth = saved_dir + "/ours_results.pth"

    gt = torch.load(gt_pth).permute(1, 0, 2, 3)
    ours = torch.load(ours_pth)["pc"].permute(1, 0, 2, 3)
all_res = {}
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
@ -534,7 +610,6 @@ def evaluate_saved(opt, saved_dir):
pprint({key: val.mean().item() for key, val in all_res.items()})
def main(opt):
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
@ -542,7 +617,7 @@ def main(opt):
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
    (outf_syn,) = setup_output_subdirs(output_dir, "syn")
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
@ -559,12 +634,10 @@ def main(opt):
model.eval()
with torch.no_grad():
logger.info("Resume Path:%s" % opt.model)
resumed_param = torch.load(opt.model)
    model.load_state_dict(resumed_param["model_state"])
if opt.eval_recon_mvr:
# Evaluate generation
@ -575,47 +648,44 @@ def main(opt):
def parse_args():
parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/")
    parser.add_argument("--dataroot_sv", default="GenReData/")
    parser.add_argument("--category", default="chair")
    parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
    parser.add_argument("--workers", type=int, default=16, help="workers")
    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
    parser.add_argument("--eval_recon_mvr", default=True)
    parser.add_argument("--eval_saved", default=True)
    parser.add_argument("--nc", default=3)
    parser.add_argument("--npoints", default=2048)
    parser.add_argument("--svpoints", default=200)

    """model"""
    parser.add_argument("--beta_start", default=0.0001)
    parser.add_argument("--beta_end", default=0.02)
    parser.add_argument("--schedule_type", default="linear")
    parser.add_argument("--time_num", default=1000)
    # params
    parser.add_argument("--attention", default=True)
    parser.add_argument("--dropout", default=0.1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--loss_type", default="mse")
    parser.add_argument("--model_mean_type", default="eps")
    parser.add_argument("--model_var_type", default="fixedsmall")
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
parser.add_argument('--model', default='', required=True, help="path to model (to continue training)")
"""eval"""
'''eval'''
parser.add_argument("--eval_path", default="")
parser.add_argument('--eval_path',
default='')
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
opt = parser.parse_args()
@ -625,7 +695,9 @@ def parse_args():
opt.cuda = False
return opt
if __name__ == "__main__":
opt = parse_args()
main(opt)

View file

@ -1,31 +1,30 @@
import argparse
from pprint import pprint

import torch
import torch.nn as nn
import torch.utils.data
from torch.distributions import Normal
from tqdm import tqdm

from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from metrics.evaluation_metrics import compute_all_metrics
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from model.pvcnn_generation import PVCNN2Base
from utils.file_utils import *
from utils.visualize import *
"""
models
"""
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
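

# Sanity-check sketch (illustrative, added): the KL divergence of a Gaussian
# with itself is zero.
def _example_normal_kl():
    zero = torch.zeros(3)
    kl = normal_kl(zero, zero, zero, zero)  # N(0, 1) vs. N(0, 1), elementwise
    assert torch.allclose(kl, torch.zeros(3))
    return kl
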
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
# Assumes data is integers [0, 1]
@ -36,37 +35,40 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 0.5)
cdf_plus = px0.cdf(plus_in)
    min_in = inv_stdv * (centered_x - 0.5)
cdf_min = px0.cdf(min_in)
    log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
    log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
cdf_delta = cdf_plus - cdf_min
    log_probs = torch.where(
        x < 0.001,
        log_cdf_plus,
        torch.where(
            x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
        ),
    )
assert log_probs.shape == x.shape
return log_probs
class GaussianDiffusion:
    def __init__(self, betas, loss_type, model_mean_type, model_var_type):
self.loss_type = loss_type
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
assert isinstance(betas, np.ndarray)
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
        (timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
# initialize twice the actual length so we can keep running for eval
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
        alphas = 1.0 - betas
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
self.betas = torch.from_numpy(betas).float()
self.alphas_cumprod = alphas_cumprod.float()
@ -74,21 +76,23 @@ class GaussianDiffusion:
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
betas = torch.from_numpy(betas).float()
alphas = torch.from_numpy(alphas).float()
# calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(
            torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
        )
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
@staticmethod
def _extract(a, t, x_shape):
@ -96,17 +100,15 @@ class GaussianDiffusion:
Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
"""
        (bs,) = t.shape
assert x_shape[0] == bs
out = torch.gather(a, 0, t)
assert out.shape == torch.Size([bs])
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
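
    # (Illustrative note, added) _extract gathers one scalar coefficient per
    # batch element and reshapes it to [B, 1, 1, ...] so it broadcasts over
    # the data, e.g. a = torch.arange(10.), t = torch.tensor([0, 9]) and
    # x_shape = (2, 3, 2048) yield a (2, 1, 1) tensor holding 0. and 9.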
def q_mean_variance(self, x_start, t):
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
return mean, variance, log_variance
@ -118,56 +120,62 @@ class GaussianDiffusion:
noise = torch.randn(x_start.shape, device=x_start.device)
assert noise.shape == x_start.shape
return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
)
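
    # (Illustrative note, added) q_sample draws x_t in closed form,
    #   x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps,  eps ~ N(0, I),
    # so training never has to simulate the chain step by step.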
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
model_output = denoise_fn(data, t)
        if self.model_var_type in ["fixedsmall", "fixedlarge"]:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                "fixedlarge": (
                    self.betas.to(data.device),
                    torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
                ),
                "fixedsmall": (
                    self.posterior_variance.to(data.device),
                    self.posterior_log_variance_clipped.to(data.device),
                ),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
else:
raise NotImplementedError(self.model_var_type)
        if self.model_mean_type == "eps":
x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
if clip_denoised:
                x_recon = torch.clamp(x_recon, -0.5, 0.5)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
else:
raise NotImplementedError(self.loss_type)
assert model_mean.shape == x_recon.shape == data.shape
assert model_variance.shape == model_log_variance.shape == data.shape
if return_pred_xstart:
@ -178,18 +186,19 @@ class GaussianDiffusion:
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
)
    """ samples """
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
"""
Sample from the model
"""
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
assert noise.shape == data.shape
# no noise when t == 0
@ -201,10 +210,17 @@ class GaussianDiffusion:
assert sample.shape == pred_xstart.shape
return (sample, pred_xstart) if return_pred_xstart else sample
    def p_sample_loop(
        self,
        denoise_fn,
        shape,
        device,
        noise_fn=torch.randn,
        constrain_fn=lambda x, t: x,
        clip_denoised=True,
        max_timestep=None,
        keep_running=False,
    ):
"""
Generate samples
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
@ -220,28 +236,38 @@ class GaussianDiffusion:
for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
img_t = constrain_fn(img_t, t)
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            ).detach()
assert img_t.shape == shape
return img_t
    def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
assert t >= 1
        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
encoding = self.q_sample(x0, t_vec)
img_t = encoding
        for k in reversed(range(0, t)):
img_t = constrain_fn(img_t, k)
t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=False,
                return_pred_xstart=False,
                use_var=True,
            ).detach()
return img_t
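
    # (Illustrative note, added) reconstruct() diffuses x0 forward to step t
    # with q_sample, then runs the reverse chain back to t = 0, applying
    # constrain_fn at every step, e.g. to pin known points during completion.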
@ -260,40 +286,50 @@ class PVCNN2(PVCNN2Base):
((128, 128, 64), (64, 2, 32)),
]
    def __init__(
        self,
        num_classes,
        embed_dim,
        use_att,
        dropout,
        extra_feature_channels=3,
        width_multiplier=1,
        voxel_resolution_multiplier=1,
    ):
super().__init__(
            num_classes=num_classes,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
)
class Model(nn.Module):
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
super(Model, self).__init__()
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
        self.model = PVCNN2(
            num_classes=args.nc,
            embed_dim=args.embed_dim,
            use_att=args.attention,
            dropout=args.dropout,
            extra_feature_channels=0,
        )
def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0)
def all_kl(self, x0, clip_denoised=True):
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
        return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
def _denoise(self, data, t):
        B, D, N = data.shape
assert data.dtype == torch.float
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
@ -307,23 +343,34 @@ class Model(nn.Module):
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
if noises is not None:
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
        losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
assert losses.shape == t.shape == torch.Size([B])
return losses
    def gen_samples(
        self,
        shape,
        device,
        noise_fn=torch.randn,
        constrain_fn=lambda x, t: x,
        clip_denoised=False,
        max_timestep=None,
        keep_running=False,
    ):
        return self.diffusion.p_sample_loop(
            self._denoise,
            shape=shape,
            device=device,
            noise_fn=noise_fn,
            constrain_fn=constrain_fn,
            clip_denoised=clip_denoised,
            max_timestep=max_timestep,
            keep_running=keep_running,
        )

    def reconstruct(self, x0, t, constrain_fn=lambda x, t: x):
return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
def train(self):
@ -337,20 +384,17 @@ class Model(nn.Module):
def get_betas(schedule_type, b_start, b_end, time_num):
    if schedule_type == "linear":
betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type == "warm0.1":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.1)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.2":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.2)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.5":
betas = b_end * np.ones(time_num, dtype=np.float64)
warmup_time = int(time_num * 0.5)
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
@ -358,111 +402,109 @@ def get_betas(schedule_type, b_start, b_end, time_num):
raise NotImplementedError(schedule_type)
return betas
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
    """
:param target_shape_constraint: target voxels
:return: constrained x
    """
# eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
    eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2))

    def constrain_fn(x, t):
        eps_ = eps_all[t] if (t < 1000) else 0
        for _ in range(num_steps):
            x = x - eps_ * ((x - ground_truth) * mask)
return x
return constrain_fn
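
# Usage sketch (illustrative, added; `gt` and `mask` are assumed tensors of
# matching shape, with mask = 1 on constrained entries):
#   constrain = get_constrain_function(gt, mask, eps=1e-2)
#   x = constrain(x, t)  # nudges masked entries toward gt, harder as t -> 0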
#############################################################################
def get_dataset(dataroot, npoints, category, use_mask=False):
    tr_dataset = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=[category],
        split="train",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
        use_mask=use_mask,
    )
    te_dataset = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=[category],
        split="val",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        use_mask=use_mask,
)
return tr_dataset, te_dataset
def evaluate_gen(opt, ref_pcs, logger):
if ref_pcs is None:
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
        )
ref = []
        for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
            x = data["test_points"]
            m, s = data["mean"].float(), data["std"].float()
            ref.append(x * s + m)
ref_pcs = torch.cat(ref, dim=0).contiguous()
logger.info("Loading sample path: %s"
% (opt.eval_path))
logger.info("Loading sample path: %s" % (opt.eval_path))
sample_pcs = torch.load(opt.eval_path).contiguous()
logger.info("Generation sample size:%s reference size: %s"
% (sample_pcs.size(), ref_pcs.size()))
logger.info("Generation sample size:%s reference size: %s" % (sample_pcs.size(), ref_pcs.size()))
# Compute metrics
results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    results = {k: (v.cpu().detach().item() if not isinstance(v, float) else v) for k, v in results.items()}
pprint(results)
logger.info(results)
jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
    pprint("JSD: {}".format(jsd))
    logger.info("JSD: {}".format(jsd))
def generate(model, opt):
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
    )
with torch.no_grad():
samples = []
ref = []
        for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Generating Samples"):
            x = data["test_points"].transpose(1, 2)
            m, s = data["mean"].float(), data["std"].float()

            gen = model.gen_samples(x.shape, "cuda", clip_denoised=False).detach().cpu()

            gen = gen.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()
gen = gen * s + m
x = x * s + m
@ -482,20 +524,20 @@ def generate(model, opt):
# 1,
# 0.5,
# )
# visualize using matplotlib
        import matplotlib
        import matplotlib.cm as cm
        import matplotlib.pyplot as plt

        matplotlib.use("TkAgg")
for idx, pc in enumerate(gen[:64]):
print(f"Visualizing point cloud {idx}...")
fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
            ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
            ax.set_aspect("equal")
            ax.axis("off")
# ax.set_xlim(-1, 1)
# ax.set_ylim(-1, 1)
# ax.set_zlim(-1, 1)
@ -507,17 +549,14 @@ def generate(model, opt):
torch.save(samples, opt.eval_path)
return ref
def main(opt):
    if opt.category == "airplane":
opt.beta_start = 1e-5
opt.beta_end = 0.008
        opt.schedule_type = "warm0.1"
exp_id = os.path.splitext(os.path.basename(__file__))[0]
dir_id = os.path.dirname(__file__)
@ -525,7 +564,7 @@ def main(opt):
copy_source(__file__, output_dir)
logger = setup_logging(output_dir)
    (outf_syn,) = setup_output_subdirs(output_dir, "syn")
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
@ -542,64 +581,59 @@ def main(opt):
model.eval()
with torch.no_grad():
logger.info("Resume Path:%s" % opt.model)
resumed_param = torch.load(opt.model)
        model.load_state_dict(resumed_param["model_state"])
ref = None
if opt.generate:
            opt.eval_path = os.path.join(outf_syn, "samples.pth")
Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
            ref = generate(model, opt)
if opt.eval_gen:
# Evaluate generation
evaluate_gen(opt, ref, logger)
def parse_args():
parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
    parser.add_argument("--category", default="chair")
    parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
    parser.add_argument("--workers", type=int, default=16, help="workers")
    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
    parser.add_argument("--generate", default=True)
    parser.add_argument("--eval_gen", default=True)
    parser.add_argument("--nc", default=3)
    parser.add_argument("--npoints", default=2048)

    """model"""
    parser.add_argument("--beta_start", default=0.0001)
    parser.add_argument("--beta_end", default=0.02)
    parser.add_argument("--schedule_type", default="linear")
    parser.add_argument("--time_num", default=1000)
    # params
    parser.add_argument("--attention", default=True)
    parser.add_argument("--dropout", default=0.1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--loss_type", default="mse")
    parser.add_argument("--model_mean_type", default="eps")
    parser.add_argument("--model_var_type", default="fixedsmall")
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
parser.add_argument('--model', default='',required=True, help="path to model (to continue training)")
"""eval"""
'''eval'''
parser.add_argument("--eval_path", default="")
parser.add_argument('--eval_path',
default='')
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
opt = parser.parse_args()
@ -609,7 +643,9 @@ def parse_args():
opt.cuda = False
return opt
if __name__ == "__main__":
opt = parse_args()
set_seed(opt)

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -1,35 +1,33 @@
import datetime
import logging
import os
import random
import sys
from shutil import copyfile

import numpy as np
import torch

logger = logging.getLogger()
def set_global_gpu_env(opt):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
torch.cuda.set_device(opt.gpu)
def copy_source(file, output_dir):
copyfile(file, os.path.join(output_dir, os.path.basename(file)))
def setup_logging(output_dir):
log_format = logging.Formatter("%(asctime)s : %(message)s")
logger = logging.getLogger()
logger.handlers = []
    output_file = os.path.join(output_dir, "output.log")
file_handler = logging.FileHandler(output_file)
file_handler.setFormatter(log_format)
logger.addHandler(file_handler)
@ -44,16 +42,14 @@ def setup_logging(output_dir):
def get_output_dir(prefix, exp_id):
    t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    output_dir = os.path.join(prefix, "output/" + exp_id, t)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
return output_dir
def set_seed(opt):
if opt.manualSeed is None:
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
@ -65,8 +61,8 @@ def set_seed(opt):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
def setup_output_subdirs(output_dir, *subfolders):
output_subdirs = output_dir
try:
os.makedirs(output_subdirs)
@ -82,4 +78,4 @@ def setup_output_subdirs(output_dir, *subfolders):
pass
subfolder_list.append(curr_subf)
    return subfolder_list

View file

@ -1,20 +1,22 @@
import warnings

import numpy as np
from scipy.stats import entropy
def iterate_in_chunks(l, n):
    """Yield successive 'n'-sized chunks from iterable 'l'.
Note: last chunk will be smaller than l if n doesn't divide l perfectly.
    """
for i in range(0, len(l), n):
        yield l[i : i + n]
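
# Usage sketch (illustrative, added):
#   list(iterate_in_chunks(list(range(5)), 2))  ->  [[0, 1], [2, 3], [4]]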
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
that is placed in the unit-cube.
    If clip_sphere is True it drops the "corner" cells that lie outside the unit-sphere.
    """
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
spacing = 1.0 / float(resolution - 1)
for i in range(resolution):
@ -30,9 +32,11 @@ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
return grid, spacing
def minimum_mathing_distance(
    sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False
):
    """Computes the MMD between two sets of point-clouds.
Args:
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and
compared to a set of "reference" point-clouds.
@ -49,17 +53,17 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
        use_EMD (boolean): If true, the matchings are based on the EMD.
Returns:
A tuple containing the MMD and all the matched distances of which the MMD is their mean.
    """
n_ref, n_pc_points, pc_dim = ref_pcs.shape
_, n_pc_points_s, pc_dim_s = sample_pcs.shape
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError("Incompatible size of point-clouds.")
    ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(
        n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
    )
matched_dists = []
for i in range(n_ref):
best_in_all_batches = []
@ -75,9 +79,18 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
return mmd, matched_dists
def coverage(
    sample_pcs,
    ref_pcs,
    batch_size,
    normalize=True,
    sess=None,
    verbose=False,
    use_sqrt=False,
    use_EMD=False,
    ret_dist=False,
):
    """Computes the Coverage between two sets of point-clouds.
Args:
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
and compared to a set of "reference" point-clouds.
@ -97,18 +110,16 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
Returns: the coverage score (int),
the indices of the ref_pcs that are matched with each sample_pc
and optionally the matched distances of the samples_pcs.
    """
n_ref, n_pc_points, pc_dim = ref_pcs.shape
n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError("Incompatible Point-Clouds.")
    ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(
        n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
    )
matched_gt = []
matched_dist = []
    for i in range(n_sam):  # was xrange, which only exists in Python 2
@ -140,12 +151,12 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
    """Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
Args:
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
resolution: (int) grid-resolution. Affects granularity of measurements.
    """
in_unit_sphere = True
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
@ -153,19 +164,19 @@ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
    """Given a collection of point-clouds, estimate the entropy of the random variables
corresponding to occupancy-grid activation patterns.
Inputs:
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
grid_resolution (int) size of occupancy grid that will be used.
    """
epsilon = 10e-4
bound = 0.5 + epsilon
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        warnings.warn("Point-clouds are not in unit cube.")
    if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
        warnings.warn("Point-clouds are not in unit sphere.")
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
grid_coordinates = grid_coordinates.reshape(-1, 3)
@ -192,13 +203,14 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
return acc_entropy / len(grid_counters), grid_counters
def jensen_shannon_divergence(P, Q):
if np.any(P < 0) or np.any(Q < 0):
        raise ValueError("Negative values.")
if len(P) != len(Q):
        raise ValueError("Non equal size.")
    P_ = P / np.sum(P)  # Ensure probabilities.
Q_ = Q / np.sum(Q)
e1 = entropy(P_, base=2)
@ -209,13 +221,14 @@ def jensen_shannon_divergence(P, Q):
res2 = _jsdiv(P_, Q_)
if not np.allclose(res, res2, atol=10e-5, rtol=0):
        warnings.warn("Numerical values of two JSD methods don't agree.")
return res
def _jsdiv(P, Q):
    """another way of computing JSD"""
def _kldiv(A, B):
a = A.copy()
b = B.copy()
@ -229,4 +242,4 @@ def _jsdiv(P, Q):
M = 0.5 * (P_ + Q_)
    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
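
# Sanity sketch (illustrative, added): the JSD of a distribution with itself
# is zero, e.g. _jsdiv(np.array([0.2, 0.8]), np.array([0.2, 0.8])) == 0.0.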

View file

@ -1,32 +1,33 @@
import matplotlib

matplotlib.use("agg")

import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import trimesh
from mpl_toolkits.mplot3d import Axes3D

"""
Custom visualization
"""
def export_to_pc_batch(dir, pcs, colors=None):
Path(dir).mkdir(parents=True, exist_ok=True)
for i, xyz in enumerate(pcs):
if colors is None:
color = None
else:
color = colors[i]
        pcwrite(os.path.join(dir, "sample_" + str(i) + ".ply"), xyz, color)
def export_to_obj(dir, meshes, transform=lambda v, f: (v, f)):
    """
transform: f(vertices, faces) --> transformed (vertices, faces)
    """
Path(dir).mkdir(parents=True, exist_ok=True)
for i, data in enumerate(meshes):
v, f = transform(data[0], data[1])
@ -36,14 +37,15 @@ def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)):
v_color = None
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
out = trimesh.exchange.obj.export_obj(mesh)
        with open(os.path.join(dir, "sample_" + str(i) + ".obj"), "w") as f:
f.write(out)
f.close()
def export_to_obj_single(path, data, transform=lambda v, f: (v, f)):
    """
transform: f(vertices, faces) --> transformed (vertices, faces)
    """
v, f = transform(data[0], data[1])
if len(data) > 2:
v_color = data[2]
@ -51,15 +53,15 @@ def export_to_obj_single(path, data, transform=lambda v,f:(v,f)):
v_color = None
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
out = trimesh.exchange.obj.export_obj(mesh)
    with open(path, "w") as f:
f.write(out)
f.close()
def meshwrite(filename, verts, faces, norms, colors):
"""Save a 3D mesh to a polygon .ply file.
"""
"""Save a 3D mesh to a polygon .ply file."""
# Write header
    ply_file = open(filename, "w")
ply_file.write("ply\n")
ply_file.write("format ascii 1.0\n")
ply_file.write("element vertex %d\n" % (verts.shape[0]))
@ -78,11 +80,20 @@ def meshwrite(filename, verts, faces, norms, colors):
# Write vertex list
for i in range(verts.shape[0]):
ply_file.write("%f %f %f %f %f %f %d %d %d\n" % (
verts[i, 0], verts[i, 1], verts[i, 2],
norms[i, 0], norms[i, 1], norms[i, 2],
colors[i, 0], colors[i, 1], colors[i, 2],
))
ply_file.write(
"%f %f %f %f %f %f %d %d %d\n"
% (
verts[i, 0],
verts[i, 1],
verts[i, 2],
norms[i, 0],
norms[i, 1],
norms[i, 2],
colors[i, 0],
colors[i, 1],
colors[i, 2],
)
)
# Write face list
for i in range(faces.shape[0]):
@ -92,14 +103,13 @@ def meshwrite(filename, verts, faces, norms, colors):
def pcwrite(filename, xyz, rgb=None):
"""Save a point cloud to a polygon .ply file.
"""
"""Save a point cloud to a polygon .ply file."""
if rgb is None:
rgb = np.ones_like(xyz) * 128
rgb = rgb.astype(np.uint8)
# Write header
    ply_file = open(filename, "w")
ply_file.write("ply\n")
ply_file.write("format ascii 1.0\n")
ply_file.write("element vertex %d\n" % (xyz.shape[0]))
@ -113,60 +123,67 @@ def pcwrite(filename, xyz, rgb=None):
# Write vertex list
for i in range(xyz.shape[0]):
ply_file.write("%f %f %f %d %d %d\n" % (
xyz[i, 0], xyz[i, 1], xyz[i, 2],
rgb[i, 0], rgb[i, 1], rgb[i, 2],
))
ply_file.write(
"%f %f %f %d %d %d\n"
% (
xyz[i, 0],
xyz[i, 1],
xyz[i, 2],
rgb[i, 0],
rgb[i, 1],
rgb[i, 2],
)
)
"""
Matplotlib Visualization
"""
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
    r"""Visualizes voxel data.
show only first num_shown
    """
    batch_size = voxels.shape[0]
voxels = voxels.squeeze(1) > threshold
num_shown = min(num_shown, batch_size)
n = int(np.sqrt(num_shown))
    fig = plt.figure(figsize=(20, 20))
for idx, pc in enumerate(voxels[:num_shown]):
        if idx >= n * n:
break
pc = voxels[idx]
        ax = fig.add_subplot(n, n, idx + 1, projection="3d")
        ax.voxels(pc, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5)
ax.view_init()
        ax.axis("off")
    plt.savefig(out_file, bbox_inches="tight")
plt.close()
def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
    r"""Visualizes point cloud data.
Args:
points (tensor): point data
normals (tensor): normal data (if existing)
out_file (string): output file
show (bool): whether the plot should be shown
    """
# Create plot
fig = plt.figure()
ax = fig.gca(projection=Axes3D.name)
ax.scatter(points[:, 2], points[:, 0], points[:, 1])
if normals is not None:
ax.quiver(
            points[:, 2], points[:, 0], points[:, 1], normals[:, 2], normals[:, 0], normals[:, 1], length=0.1, color="k"
)
    ax.set_xlabel("Z")
    ax.set_ylabel("X")
    ax.set_zlabel("Y")
# ax.set_xlim(-0.5, 0.5)
# ax.set_ylim(-0.5, 0.5)
# ax.set_zlim(-0.5, 0.5)
@ -178,37 +195,39 @@ def visualize_pointcloud(points, normals=None,
plt.close(fig)
def visualize_pointcloud_batch(
    path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225
):
batch_size = len(pointclouds)
    fig = plt.figure(figsize=(20, 20))
ncols = int(np.sqrt(batch_size))
    nrows = max(1, (batch_size - 1) // ncols + 1)
for idx, pc in enumerate(pointclouds):
if vis_label:
label = categories[labels[idx].item()]
pred = categories[pred_labels[idx]]
            colour = "g" if label == pred else "r"
elif target is None:
            colour = "g"
else:
colour = target[idx]
pc = pc.cpu().numpy()
        ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
ax.view_init(elev=elev, azim=azim)
        ax.axis("off")
if vis_label:
            ax.set_title("GT: {0}\nPred: {1}".format(label, pred))
plt.savefig(path)
plt.close(fig)
"""
Plot stats
"""
def plot_stats(output_dir, stats, interval):
content = stats.keys()
@ -218,5 +237,5 @@ def plot_stats(output_dir, stats, interval):
axs[j].plot(interval, v)
axs[j].set_ylabel(k)
    f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
plt.close(f)