style: autoformatting
This commit is contained in:
parent
d887d74852
commit
2fbfc320f2
|
@ -1,32 +1,33 @@
|
|||
|
||||
from glob import glob
|
||||
import re
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def raw_camparam_from_xml(path, pose="lookAt"):
    """Read the raw camera parameters stored in a scene XML file.

    Args:
        path: Path of the XML file to parse.
        pose: Tag name of the pose element looked up under
            ``./sensor/transform/``. Its ``origin``, ``target`` and ``up``
            attributes must hold comma-separated numbers.

    Returns:
        dict with the keys ``origin``, ``up``, ``target`` (float32 arrays
        parsed from the pose attributes) and ``height``, ``width`` (ints read
        from the ``value`` attribute of the film's ``integer`` elements).

    Raises:
        ValueError: If the pose element or one of the film size elements is
            not present in the file.
        KeyError: If a present element lacks one of the required attributes.
    """
    import xml.etree.ElementTree as ET

    tree = ET.parse(path)

    def _require(xpath):
        # ``find`` returns None for a missing node; report the offending file
        # and XPath instead of failing later on ``None.attrib``.
        node = tree.find(xpath)
        if node is None:
            raise ValueError("{}: no element matches {!r}".format(path, xpath))
        return node

    pose_attrs = _require("./sensor/transform/" + pose).attrib

    def _vector(key):
        # Text-mode parsing of e.g. "1,2,3" into a float32 array.
        return np.fromstring(pose_attrs[key], dtype=np.float32, sep=",")

    def _film_int(name):
        node = _require("./sensor/film/integer[@name='{}']".format(name))
        return int(node.attrib["value"])

    origin = _vector("origin")
    target = _vector("target")
    up = _vector("up")
    height = _film_int("height")
    width = _film_int("width")

    return {
        "origin": origin,
        "up": up,
        "target": target,
        "height": height,
        "width": width,
    }
|
||||
|
||||
|
||||
def get_cam_pos(origin, target, up):
|
||||
inward = origin - target
|
||||
right = np.cross(up, inward)
|
||||
|
@ -38,59 +39,54 @@ def get_cam_pos(origin, target, up):
|
|||
ry /= np.linalg.norm(ry)
|
||||
rz /= np.linalg.norm(rz)
|
||||
|
||||
rot = np.stack([
|
||||
rx,
|
||||
ry,
|
||||
-rz
|
||||
], axis=0)
|
||||
|
||||
|
||||
aff = np.concatenate([
|
||||
np.eye(3), -origin[:,None]
|
||||
], axis=1)
|
||||
rot = np.stack([rx, ry, -rz], axis=0)
|
||||
|
||||
aff = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
|
||||
|
||||
ext = np.matmul(rot, aff)
|
||||
|
||||
result = np.concatenate(
|
||||
[ext, np.array([[0,0,0,1]])], axis=0
|
||||
)
|
||||
|
||||
|
||||
result = np.concatenate([ext, np.array([[0, 0, 0, 1]])], axis=0)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
    """Write a ``*cam_params.npz`` file next to each depth map in a directory.

    For every ``*depth.png`` in ``datapoint_dir`` the matching camera XML is
    located by swapping the ``dataroot`` part of the path for
    ``camera_param_dir``, dropping the last ``_``-separated component of the
    path and appending ``.xml``. Views whose XML file does not exist are
    skipped silently.

    Args:
        datapoint_dir: Directory that is searched for ``*depth.png`` files.
        dataroot: Root of the rendered data; surrounding ``/`` are stripped
            before it is matched inside each depth-map path.
        camera_param_dir: Root of the camera XML tree; surrounding ``/`` are
            stripped before it is substituted for ``dataroot``.

    Returns:
        None. Each processed view gets an ``.npz`` archive with the arrays
        ``extr`` (result of ``get_cam_pos``) and ``intr`` (3x3 matrix ``K``).
    """
    depth_paths = sorted(glob(os.path.join(datapoint_dir, "*depth.png")))
    data_prefix = dataroot.strip("/")
    xml_prefix = camera_param_dir.strip("/")

    # The intrinsics depend only on the constants below, so build them once.
    # presumably a 36 mm x 24 mm sensor and a 50 mm lens expressed in metres
    # -- TODO confirm.
    # NOTE(review): the image size is fixed at 480x480 here; the width/height
    # stored in the XML are not used -- confirm this is intended.
    diag = (0.036**2 + 0.024**2) ** 0.5
    focal_length = 0.05
    res = [480, 480]
    h_relative = res[1] / res[0]
    sensor_width = np.sqrt(diag**2 / (1 + h_relative**2))
    pix_size = sensor_width / res[0]

    K = np.array(
        [
            [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
            [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
            [0, 0, 1],
        ]
    )

    for depth_path in depth_paths:
        # Literal substitution: the prefixes are filesystem paths, not regular
        # expressions, so characters such as "." or "+" must not be treated as
        # regex syntax (nor backslashes as replacement escapes).
        xml_path = "_".join(depth_path.replace(data_prefix, xml_prefix).split("_")[:-1]) + ".xml"
        if not os.path.exists(xml_path):
            continue

        params = raw_camparam_from_xml(xml_path)
        ext_matrix = get_cam_pos(params["origin"], params["target"], params["up"])

        np.savez(depth_path.split("depth.png")[0] + "cam_params.npz", extr=ext_matrix, intr=K)
|
||||
|
||||
|
||||
def main(opt):
|
||||
|
@ -102,21 +98,16 @@ def main(opt):
|
|||
if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
|
||||
leaf_subdirs.append(dirpath)
|
||||
|
||||
|
||||
|
||||
for k, dir_ in enumerate(leaf_subdirs):
|
||||
print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
|
||||
print("Processing dir {}/{}: {}".format(k, len(leaf_subdirs), dir_))
|
||||
|
||||
convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = argparse.ArgumentParser()
|
||||
args.add_argument('--dataroot', type=str, default='GenReData/')
|
||||
args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
|
||||
args.add_argument("--dataroot", type=str, default="GenReData/")
|
||||
args.add_argument("--mitsuba_xml_root", type=str, default="GenReData/genre-xml_v2")
|
||||
|
||||
opt = args.parse_args()
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from torch.utils.data import Dataset, DataLoader
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import trimesh
|
||||
from plyfile import PlyData, PlyElement
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def project_pc_to_image(points, resolution=64):
|
||||
"""project point clouds into 2D image
|
||||
|
@ -26,29 +28,32 @@ def project_pc_to_image(points, resolution=64):
|
|||
|
||||
|
||||
def write_ply(points, filename, text=False):
    """Write an Nx3 point array to ``filename`` in PLY format.

    Args:
        points: Array-like of shape (N, 3); columns 0, 1, 2 are stored as the
            float32 vertex properties ``x``, ``y``, ``z``.
        filename: Destination path; the file is opened in binary write mode.
        text: Forwarded to ``PlyData(..., text=text)``.
    """
    points = np.asarray(points)
    # Fill the structured vertex array column by column instead of building
    # one Python tuple per point.
    vertex = np.empty(points.shape[0], dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    vertex["x"] = points[:, 0]
    vertex["y"] = points[:, 1]
    vertex["z"] = points[:, 2]
    el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
    with open(filename, mode="wb") as f:
        PlyData([el], text=text).write(f)
|
||||
|
||||
|
||||
def rotate_point_cloud(points, transformation_mat):
    """Apply ``transformation_mat`` to every row of ``points``.

    Args:
        points: Array-like whose rows are the points to transform.
        transformation_mat: Matrix multiplied from the left onto the
            transposed point array.

    Returns:
        Array holding ``(transformation_mat . points^T)^T``, i.e. one
        transformed point per row.
    """
    # asanyarray lets plain nested lists through while leaving ndarray
    # (sub)classes untouched.
    points = np.asanyarray(points)
    return np.dot(transformation_mat, points.T).T
|
||||
|
||||
|
||||
def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
|
||||
""" align 3depn shapes to shapenet coordinates"""
|
||||
"""align 3depn shapes to shapenet coordinates"""
|
||||
# angle = math.radians(angle_deg)
|
||||
# rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
|
||||
# rot_m = rot_m.to_matrix()
|
||||
rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
|
||||
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
|
||||
[-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
|
||||
rot_m = np.array(
|
||||
[
|
||||
[2.22044605e-16, 0.00000000e00, 1.00000000e00],
|
||||
[0.00000000e00, 1.00000000e00, 0.00000000e00],
|
||||
[-1.00000000e00, 0.00000000e00, 2.22044605e-16],
|
||||
]
|
||||
)
|
||||
|
||||
new_points = rotate_point_cloud(points, rot_m)
|
||||
|
||||
|
@ -87,14 +92,13 @@ def sample_point_cloud_by_n(points, n_pts):
|
|||
return points
|
||||
|
||||
|
||||
|
||||
def collect_data_id(split_dir, classname, phase):
|
||||
filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
|
||||
if not os.path.exists(filename):
|
||||
raise ValueError("Invalid filepath: {}".format(filename))
|
||||
|
||||
all_ids = []
|
||||
with open(filename, 'r') as fp:
|
||||
with open(filename, "r") as fp:
|
||||
info = json.load(fp)
|
||||
for item in info:
|
||||
all_ids.append(item["anno_id"])
|
||||
|
@ -102,7 +106,6 @@ def collect_data_id(split_dir, classname, phase):
|
|||
return all_ids
|
||||
|
||||
|
||||
|
||||
class GANdatasetPartNet(Dataset):
|
||||
def __init__(self, phase, data_root, category, n_pts):
|
||||
super(GANdatasetPartNet, self).__init__()
|
||||
|
@ -114,10 +117,12 @@ class GANdatasetPartNet(Dataset):
|
|||
|
||||
self.data_root = data_root
|
||||
|
||||
shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase)
|
||||
shape_names = collect_data_id(
|
||||
os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
|
||||
)
|
||||
self.shape_names = []
|
||||
for name in shape_names:
|
||||
path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
|
||||
path = os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name)
|
||||
if os.path.exists(path):
|
||||
self.shape_names.append(name)
|
||||
|
||||
|
@ -129,12 +134,12 @@ class GANdatasetPartNet(Dataset):
|
|||
@staticmethod
|
||||
def load_point_cloud(path):
|
||||
pc = trimesh.load(path)
|
||||
pc = pc.vertices / 2.0 # scale to unit sphere
|
||||
pc = pc.vertices / 2.0 # scale to unit sphere
|
||||
return pc
|
||||
|
||||
@staticmethod
|
||||
def read_point_cloud_part_label(path):
|
||||
with open(path, 'r') as fp:
|
||||
with open(path, "r") as fp:
|
||||
labels = fp.readlines()
|
||||
labels = np.array([int(x) for x in labels])
|
||||
return labels
|
||||
|
@ -156,26 +161,31 @@ class GANdatasetPartNet(Dataset):
|
|||
|
||||
def __getitem__(self, index):
|
||||
raw_shape_name = self.shape_names[index]
|
||||
raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
|
||||
raw_ply_path = os.path.join(self.data_root, "partnet_data", raw_shape_name, "point_sample/ply-10000.ply")
|
||||
raw_pc = self.load_point_cloud(raw_ply_path)
|
||||
|
||||
raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
|
||||
raw_label_path = os.path.join(
|
||||
self.data_root, "partnet_labels/partnet_pc_label", raw_shape_name, "label-merge-level1-10000.txt"
|
||||
)
|
||||
part_labels = self.read_point_cloud_part_label(raw_label_path)
|
||||
raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
|
||||
raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
|
||||
raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
|
||||
|
||||
real_shape_name = self.shape_names[index]
|
||||
real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
|
||||
real_ply_path = os.path.join(self.data_root, "partnet_data", real_shape_name, "point_sample/ply-10000.ply")
|
||||
real_pc = self.load_point_cloud(real_ply_path)
|
||||
real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
|
||||
real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
|
||||
|
||||
return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
|
||||
'n_part_keep': n_part_keep, 'idx': index}
|
||||
return {
|
||||
"raw": raw_pc,
|
||||
"real": real_pc,
|
||||
"raw_id": raw_shape_name,
|
||||
"real_id": real_shape_name,
|
||||
"n_part_keep": n_part_keep,
|
||||
"idx": index,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.shape_names)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,33 +1,67 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils import data
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
|
||||
synsetid_to_cate = {
|
||||
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
|
||||
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
|
||||
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
|
||||
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
|
||||
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
|
||||
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
|
||||
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
|
||||
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
|
||||
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
|
||||
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
|
||||
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
|
||||
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
|
||||
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
|
||||
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
|
||||
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
|
||||
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
|
||||
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
|
||||
'04554684': 'washer', '02992529': 'cellphone',
|
||||
'02843684': 'birdhouse', '02871439': 'bookshelf',
|
||||
"02691156": "airplane",
|
||||
"02773838": "bag",
|
||||
"02801938": "basket",
|
||||
"02808440": "bathtub",
|
||||
"02818832": "bed",
|
||||
"02828884": "bench",
|
||||
"02876657": "bottle",
|
||||
"02880940": "bowl",
|
||||
"02924116": "bus",
|
||||
"02933112": "cabinet",
|
||||
"02747177": "can",
|
||||
"02942699": "camera",
|
||||
"02954340": "cap",
|
||||
"02958343": "car",
|
||||
"03001627": "chair",
|
||||
"03046257": "clock",
|
||||
"03207941": "dishwasher",
|
||||
"03211117": "monitor",
|
||||
"04379243": "table",
|
||||
"04401088": "telephone",
|
||||
"02946921": "tin_can",
|
||||
"04460130": "tower",
|
||||
"04468005": "train",
|
||||
"03085013": "keyboard",
|
||||
"03261776": "earphone",
|
||||
"03325088": "faucet",
|
||||
"03337140": "file",
|
||||
"03467517": "guitar",
|
||||
"03513137": "helmet",
|
||||
"03593526": "jar",
|
||||
"03624134": "knife",
|
||||
"03636649": "lamp",
|
||||
"03642806": "laptop",
|
||||
"03691459": "speaker",
|
||||
"03710193": "mailbox",
|
||||
"03759954": "microphone",
|
||||
"03761084": "microwave",
|
||||
"03790512": "motorcycle",
|
||||
"03797390": "mug",
|
||||
"03928116": "piano",
|
||||
"03938244": "pillow",
|
||||
"03948459": "pistol",
|
||||
"03991062": "pot",
|
||||
"04004475": "printer",
|
||||
"04074963": "remote_control",
|
||||
"04090263": "rifle",
|
||||
"04099429": "rocket",
|
||||
"04225987": "skateboard",
|
||||
"04256520": "sofa",
|
||||
"04330267": "stove",
|
||||
"04530566": "vessel",
|
||||
"04554684": "washer",
|
||||
"02992529": "cellphone",
|
||||
"02843684": "birdhouse",
|
||||
"02871439": "bookshelf",
|
||||
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
||||
# '02834778': 'bicycle', not in our taxonomy
|
||||
}
|
||||
|
@ -35,13 +69,23 @@ cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
|
|||
|
||||
|
||||
class Uniform15KPC(Dataset):
|
||||
def __init__(self, root_dir, subdirs, tr_sample_size=10000,
|
||||
te_sample_size=10000, split='train', scale=1.,
|
||||
normalize_per_shape=False, box_per_shape=False,
|
||||
random_subsample=False,
|
||||
normalize_std_per_axis=False,
|
||||
all_points_mean=None, all_points_std=None,
|
||||
input_dim=3, use_mask=False):
|
||||
def __init__(
|
||||
self,
|
||||
root_dir,
|
||||
subdirs,
|
||||
tr_sample_size=10000,
|
||||
te_sample_size=10000,
|
||||
split="train",
|
||||
scale=1.0,
|
||||
normalize_per_shape=False,
|
||||
box_per_shape=False,
|
||||
random_subsample=False,
|
||||
normalize_std_per_axis=False,
|
||||
all_points_mean=None,
|
||||
all_points_std=None,
|
||||
input_dim=3,
|
||||
use_mask=False,
|
||||
):
|
||||
self.root_dir = root_dir
|
||||
self.split = split
|
||||
self.in_tr_sample_size = tr_sample_size
|
||||
|
@ -67,9 +111,9 @@ class Uniform15KPC(Dataset):
|
|||
|
||||
all_mids = []
|
||||
for x in os.listdir(sub_path):
|
||||
if not x.endswith('.npy'):
|
||||
if not x.endswith(".npy"):
|
||||
continue
|
||||
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
|
||||
all_mids.append(os.path.join(self.split, x[: -len(".npy")]))
|
||||
|
||||
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
|
||||
for mid in all_mids:
|
||||
|
@ -111,7 +155,9 @@ class Uniform15KPC(Dataset):
|
|||
B, N = self.all_points.shape[:2]
|
||||
self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
|
||||
|
||||
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
|
||||
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(
|
||||
axis=1
|
||||
).reshape(B, 1, input_dim)
|
||||
|
||||
else: # normalize across the dataset
|
||||
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
|
||||
|
@ -129,8 +175,7 @@ class Uniform15KPC(Dataset):
|
|||
self.tr_sample_size = min(10000, tr_sample_size)
|
||||
self.te_sample_size = min(5000, te_sample_size)
|
||||
print("Total number of data:%d" % len(self.train_points))
|
||||
print("Min number of points: (train)%d (test)%d"
|
||||
% (self.tr_sample_size, self.te_sample_size))
|
||||
print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
|
||||
assert self.scale == 1, "Scale (!= 1) is deprecated"
|
||||
|
||||
def get_pc_stats(self, idx):
|
||||
|
@ -139,7 +184,6 @@ class Uniform15KPC(Dataset):
|
|||
s = self.all_points_std[idx].reshape(1, -1)
|
||||
return m, s
|
||||
|
||||
|
||||
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
|
||||
|
||||
def renormalize(self, mean, std):
|
||||
|
@ -173,11 +217,14 @@ class Uniform15KPC(Dataset):
|
|||
sid, mid = self.all_cate_mids[idx]
|
||||
|
||||
out = {
|
||||
'idx': idx,
|
||||
'train_points': tr_out,
|
||||
'test_points': te_out,
|
||||
'mean': m, 'std': s, 'cate_idx': cate_idx,
|
||||
'sid': sid, 'mid': mid
|
||||
"idx": idx,
|
||||
"train_points": tr_out,
|
||||
"test_points": te_out,
|
||||
"mean": m,
|
||||
"std": s,
|
||||
"cate_idx": cate_idx,
|
||||
"sid": sid,
|
||||
"mid": mid,
|
||||
}
|
||||
|
||||
if self.use_mask:
|
||||
|
@ -192,26 +239,35 @@ class Uniform15KPC(Dataset):
|
|||
# out['train_points_masked'] = masked
|
||||
# out['train_masks'] = tr_mask
|
||||
tr_mask = self.mask_transform(tr_out)
|
||||
out['train_masks'] = tr_mask
|
||||
out["train_masks"] = tr_mask
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ShapeNet15kPointClouds(Uniform15KPC):
|
||||
def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
|
||||
categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
|
||||
split='train', scale=1., normalize_per_shape=False,
|
||||
normalize_std_per_axis=False, box_per_shape=False,
|
||||
random_subsample=False,
|
||||
all_points_mean=None, all_points_std=None,
|
||||
use_mask=False):
|
||||
def __init__(
|
||||
self,
|
||||
root_dir="data/ShapeNetCore.v2.PC15k",
|
||||
categories=["airplane"],
|
||||
tr_sample_size=10000,
|
||||
te_sample_size=2048,
|
||||
split="train",
|
||||
scale=1.0,
|
||||
normalize_per_shape=False,
|
||||
normalize_std_per_axis=False,
|
||||
box_per_shape=False,
|
||||
random_subsample=False,
|
||||
all_points_mean=None,
|
||||
all_points_std=None,
|
||||
use_mask=False,
|
||||
):
|
||||
self.root_dir = root_dir
|
||||
self.split = split
|
||||
assert self.split in ['train', 'test', 'val']
|
||||
assert self.split in ["train", "test", "val"]
|
||||
self.tr_sample_size = tr_sample_size
|
||||
self.te_sample_size = te_sample_size
|
||||
self.cates = categories
|
||||
if 'all' in categories:
|
||||
if "all" in categories:
|
||||
self.synset_ids = list(cate_to_synsetid.values())
|
||||
else:
|
||||
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
|
||||
|
@ -221,19 +277,21 @@ class ShapeNet15kPointClouds(Uniform15KPC):
|
|||
self.display_axis_order = [0, 2, 1]
|
||||
|
||||
super(ShapeNet15kPointClouds, self).__init__(
|
||||
root_dir, self.synset_ids,
|
||||
root_dir,
|
||||
self.synset_ids,
|
||||
tr_sample_size=tr_sample_size,
|
||||
te_sample_size=te_sample_size,
|
||||
split=split, scale=scale,
|
||||
normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
|
||||
split=split,
|
||||
scale=scale,
|
||||
normalize_per_shape=normalize_per_shape,
|
||||
box_per_shape=box_per_shape,
|
||||
normalize_std_per_axis=normalize_std_per_axis,
|
||||
random_subsample=random_subsample,
|
||||
all_points_mean=all_points_mean, all_points_std=all_points_std,
|
||||
input_dim=3, use_mask=use_mask)
|
||||
|
||||
|
||||
all_points_mean=all_points_mean,
|
||||
all_points_std=all_points_std,
|
||||
input_dim=3,
|
||||
use_mask=use_mask,
|
||||
)
|
||||
|
||||
|
||||
####################################################################################
|
||||
|
||||
|
||||
|
|
|
@ -1,34 +1,70 @@
|
|||
import hashlib
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import hashlib
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
synset_to_label = {
|
||||
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
|
||||
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
|
||||
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
|
||||
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
|
||||
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
|
||||
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
|
||||
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
|
||||
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
|
||||
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
|
||||
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
|
||||
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
|
||||
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
|
||||
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
|
||||
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
|
||||
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
|
||||
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
|
||||
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
|
||||
'04554684': 'washer', '02992529': 'cellphone',
|
||||
'02843684': 'birdhouse', '02871439': 'bookshelf',
|
||||
"02691156": "airplane",
|
||||
"02773838": "bag",
|
||||
"02801938": "basket",
|
||||
"02808440": "bathtub",
|
||||
"02818832": "bed",
|
||||
"02828884": "bench",
|
||||
"02876657": "bottle",
|
||||
"02880940": "bowl",
|
||||
"02924116": "bus",
|
||||
"02933112": "cabinet",
|
||||
"02747177": "can",
|
||||
"02942699": "camera",
|
||||
"02954340": "cap",
|
||||
"02958343": "car",
|
||||
"03001627": "chair",
|
||||
"03046257": "clock",
|
||||
"03207941": "dishwasher",
|
||||
"03211117": "monitor",
|
||||
"04379243": "table",
|
||||
"04401088": "telephone",
|
||||
"02946921": "tin_can",
|
||||
"04460130": "tower",
|
||||
"04468005": "train",
|
||||
"03085013": "keyboard",
|
||||
"03261776": "earphone",
|
||||
"03325088": "faucet",
|
||||
"03337140": "file",
|
||||
"03467517": "guitar",
|
||||
"03513137": "helmet",
|
||||
"03593526": "jar",
|
||||
"03624134": "knife",
|
||||
"03636649": "lamp",
|
||||
"03642806": "laptop",
|
||||
"03691459": "speaker",
|
||||
"03710193": "mailbox",
|
||||
"03759954": "microphone",
|
||||
"03761084": "microwave",
|
||||
"03790512": "motorcycle",
|
||||
"03797390": "mug",
|
||||
"03928116": "piano",
|
||||
"03938244": "pillow",
|
||||
"03948459": "pistol",
|
||||
"03991062": "pot",
|
||||
"04004475": "printer",
|
||||
"04074963": "remote_control",
|
||||
"04090263": "rifle",
|
||||
"04099429": "rocket",
|
||||
"04225987": "skateboard",
|
||||
"04256520": "sofa",
|
||||
"04330267": "stove",
|
||||
"04530566": "vessel",
|
||||
"04554684": "washer",
|
||||
"02992529": "cellphone",
|
||||
"02843684": "birdhouse",
|
||||
"02871439": "bookshelf",
|
||||
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
||||
# '02834778': 'bicycle', not in our taxonomy
|
||||
}
|
||||
|
@ -36,30 +72,44 @@ synset_to_label = {
|
|||
# Label to Synset mapping (for ShapeNet core classes)
|
||||
label_to_synset = {v: k for k, v in synset_to_label.items()}
|
||||
|
||||
|
||||
def _convert_categories(categories):
    """Translate category labels into synset ids.

    Args:
        categories: Iterable of category labels and/or synset ids. Must not
            be ``None``.

    Returns:
        list with one entry per requested category: the synset id from
        ``label_to_synset`` when the entry is a known label, otherwise the
        entry itself unchanged.

    Warns:
        UserWarning: If any entry is neither a key of ``synset_to_label`` nor
            a key of ``label_to_synset``.
    """
    assert categories is not None, "List of categories cannot be empty!"
    # A bare generator expression is always truthy, so the check has to be
    # reduced with all(); membership is tested on each mapping separately
    # because dict key views cannot be concatenated with "+".
    if not all(c in synset_to_label or c in label_to_synset for c in categories):
        warnings.warn(
            "Some or all of the categories requested are not part of "
            "ShapeNetCore. Data loading may fail if these categories are not available."
        )
    return [label_to_synset.get(c, c) for c in categories]
|
||||
|
||||
|
||||
class ShapeNet_Multiview_Points(Dataset):
|
||||
def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val',
|
||||
npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
|
||||
def __init__(
|
||||
self,
|
||||
root_pc: str,
|
||||
root_views: str,
|
||||
cache: str,
|
||||
categories: list = ["chair"],
|
||||
split: str = "val",
|
||||
npoints=2048,
|
||||
sv_samples=800,
|
||||
all_points_mean=None,
|
||||
all_points_std=None,
|
||||
get_image=False,
|
||||
):
|
||||
self.root = Path(root_views)
|
||||
self.split = split
|
||||
self.get_image = get_image
|
||||
params = {
|
||||
'cat': categories,
|
||||
'npoints': npoints,
|
||||
'sv_samples': sv_samples,
|
||||
"cat": categories,
|
||||
"npoints": npoints,
|
||||
"sv_samples": sv_samples,
|
||||
}
|
||||
params = tuple(sorted(pair for pair in params.items()))
|
||||
self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
|
||||
self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
|
||||
"_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
|
||||
)
|
||||
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.paths = []
|
||||
|
@ -74,13 +124,12 @@ class ShapeNet_Multiview_Points(Dataset):
|
|||
|
||||
# loops through desired classes
|
||||
for i in range(len(self.synsets)):
|
||||
|
||||
syn = self.synsets[i]
|
||||
class_target = self.root / syn
|
||||
if not class_target.exists():
|
||||
raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
|
||||
syn, self.labels[i], str(class_target)))
|
||||
|
||||
raise ValueError(
|
||||
"Class {0} ({1}) was not found at location {2}.".format(syn, self.labels[i], str(class_target))
|
||||
)
|
||||
|
||||
sub_path_pc = os.path.join(root_pc, syn, split)
|
||||
if not os.path.isdir(sub_path_pc):
|
||||
|
@ -90,30 +139,30 @@ class ShapeNet_Multiview_Points(Dataset):
|
|||
self.all_mids = []
|
||||
self.imgs = []
|
||||
for x in os.listdir(sub_path_pc):
|
||||
if not x.endswith('.npy'):
|
||||
if not x.endswith(".npy"):
|
||||
continue
|
||||
self.all_mids.append(os.path.join(split, x[:-len('.npy')]))
|
||||
self.all_mids.append(os.path.join(split, x[: -len(".npy")]))
|
||||
|
||||
for mid in tqdm(self.all_mids):
|
||||
# obj_fname = os.path.join(sub_path, x)
|
||||
obj_fname = os.path.join(root_pc, syn, mid + ".npy")
|
||||
cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz'))
|
||||
cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
|
||||
if len(cams_pths) < 20:
|
||||
continue
|
||||
point_cloud = np.load(obj_fname)
|
||||
sv_points_group = []
|
||||
img_path_group = []
|
||||
(self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True)
|
||||
(self.cache_dir / (mid.split("/")[-1])).mkdir(parents=True, exist_ok=True)
|
||||
success = True
|
||||
for i, cp in enumerate(cams_pths):
|
||||
cp = str(cp)
|
||||
vp = cp.split('cam_params')[0] + 'depth.png'
|
||||
depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
|
||||
cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) )
|
||||
vp = cp.split("cam_params")[0] + "depth.png"
|
||||
depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
|
||||
cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))
|
||||
|
||||
cam_params = np.load(cp)
|
||||
extr = cam_params['extr']
|
||||
intr = cam_params['intr']
|
||||
extr = cam_params["extr"]
|
||||
intr = cam_params["intr"]
|
||||
|
||||
self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
|
||||
|
||||
|
@ -125,7 +174,7 @@ class ShapeNet_Multiview_Points(Dataset):
|
|||
sv_points_group.append(sv_point_cloud)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
success=False
|
||||
success = False
|
||||
break
|
||||
if not success:
|
||||
continue
|
||||
|
@ -144,64 +193,66 @@ class ShapeNet_Multiview_Points(Dataset):
|
|||
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
|
||||
|
||||
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
|
||||
self.train_points = self.all_points[:,:10000]
|
||||
self.test_points = self.all_points[:,10000:]
|
||||
self.train_points = self.all_points[:, :10000]
|
||||
self.test_points = self.all_points[:, 10000:]
|
||||
self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std
|
||||
|
||||
def get_pc_stats(self, idx):
|
||||
|
||||
return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1)
|
||||
return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the dataset. """
|
||||
"""Returns the length of the dataset."""
|
||||
return len(self.all_points)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
|
||||
tr_out = self.train_points[index]
|
||||
tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
|
||||
tr_out = tr_out[tr_idxs, :]
|
||||
|
||||
gt_points = self.test_points[index][:self.npoints]
|
||||
gt_points = self.test_points[index][: self.npoints]
|
||||
|
||||
m, s = self.get_pc_stats(index)
|
||||
|
||||
sv_points = self.all_points_sv[index]
|
||||
|
||||
idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False)
|
||||
idxs = np.arange(0, sv_points.shape[-2])[
|
||||
: self.sv_samples
|
||||
] # np.random.choice(sv_points.shape[0], 500, replace=False)
|
||||
|
||||
data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(),
|
||||
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
|
||||
data = torch.cat(
|
||||
[
|
||||
torch.from_numpy(sv_points[:, idxs]).float(),
|
||||
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
masks = torch.zeros_like(data)
|
||||
masks[:,:idxs.shape[0]] = 1
|
||||
masks[:, : idxs.shape[0]] = 1
|
||||
|
||||
res = {'train_points': torch.from_numpy(tr_out).float(),
|
||||
'test_points': torch.from_numpy(gt_points).float(),
|
||||
'sv_points': data,
|
||||
'masks': masks,
|
||||
'std': s, 'mean': m,
|
||||
'idx': index,
|
||||
'name':self.all_mids[index]
|
||||
}
|
||||
|
||||
if self.split != 'train' and self.get_image:
|
||||
res = {
|
||||
"train_points": torch.from_numpy(tr_out).float(),
|
||||
"test_points": torch.from_numpy(gt_points).float(),
|
||||
"sv_points": data,
|
||||
"masks": masks,
|
||||
"std": s,
|
||||
"mean": m,
|
||||
"idx": index,
|
||||
"name": self.all_mids[index],
|
||||
}
|
||||
|
||||
if self.split != "train" and self.get_image:
|
||||
img_lst = []
|
||||
for n in range(self.all_points_sv.shape[1]):
|
||||
|
||||
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3]
|
||||
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
|
||||
|
||||
img_lst.append(img)
|
||||
|
||||
img = torch.stack(img_lst, dim=0)
|
||||
|
||||
res['image'] = img
|
||||
res["image"] = img
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
||||
def _render(self, cache_path, depth_pth, depth_minmax_pth):
|
||||
# if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
|
||||
#
|
||||
|
@ -210,11 +261,9 @@ class ShapeNet_Multiview_Points(Dataset):
|
|||
if os.path.exists(cache_path):
|
||||
data = np.load(cache_path)
|
||||
else:
|
||||
|
||||
data, depth = self.transform(depth_pth, depth_minmax_pth)
|
||||
assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
|
||||
assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
|
||||
data = data[np.random.choice(data.shape[0], 600, replace=False)]
|
||||
np.save(cache_path, data)
|
||||
|
||||
return data
|
||||
|
||||
|
|
|
@ -1,25 +1,32 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
chamfer_found = importlib.find_loader("chamfer_2D") is not None
|
||||
if not chamfer_found:
|
||||
## Cool trick from https://github.com/chrdiller
|
||||
print("Jitting Chamfer 2D")
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_2D = load(name="chamfer_2D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
|
||||
])
|
||||
|
||||
chamfer_2D = load(
|
||||
name="chamfer_2D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer2D.cu"]),
|
||||
],
|
||||
)
|
||||
print("Loaded JIT 2D CUDA chamfer distance")
|
||||
|
||||
else:
|
||||
import chamfer_2D
|
||||
|
||||
print("Loaded compiled 2D CUDA chamfer distance")
|
||||
|
||||
|
||||
# Chamfer's distance module @thibaultgroueix
|
||||
# GPU tensors only
|
||||
class chamfer_2DFunction(Function):
|
||||
|
@ -57,9 +64,7 @@ class chamfer_2DFunction(Function):
|
|||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_2D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
chamfer_2D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
|
|
|
@ -2,15 +2,16 @@ from setuptools import setup
|
|||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_2D',
|
||||
name="chamfer_2D",
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_2D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
|
||||
]),
|
||||
CUDAExtension(
|
||||
"chamfer_2D",
|
||||
[
|
||||
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(__file__.split("/")[:-1] + ["chamfer2D.cu"]),
|
||||
],
|
||||
),
|
||||
],
|
||||
|
||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
|
|
|
@ -1,25 +1,30 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
chamfer_found = importlib.find_loader("chamfer_3D") is not None
|
||||
if not chamfer_found:
|
||||
## Cool trick from https://github.com/chrdiller
|
||||
print("Jitting Chamfer 3D")
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_3D = load(name="chamfer_3D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
|
||||
],
|
||||
|
||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],)
|
||||
chamfer_3D = load(
|
||||
name="chamfer_3D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer3D.cu"]),
|
||||
],
|
||||
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||
)
|
||||
print("Loaded JIT 3D CUDA chamfer distance")
|
||||
|
||||
else:
|
||||
import chamfer_3D
|
||||
|
||||
print("Loaded compiled 3D CUDA chamfer distance")
|
||||
|
||||
|
||||
|
@ -60,9 +65,7 @@ class chamfer_3DFunction(Function):
|
|||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_3D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
chamfer_3D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
|
@ -74,4 +77,3 @@ class chamfer_3DDist(nn.Module):
|
|||
input1 = input1.contiguous()
|
||||
input2 = input2.contiguous()
|
||||
return chamfer_3DFunction.apply(input1, input2)
|
||||
|
||||
|
|
|
@ -2,15 +2,16 @@ from setuptools import setup
|
|||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_3D',
|
||||
name="chamfer_3D",
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_3D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
|
||||
]),
|
||||
CUDAExtension(
|
||||
"chamfer_3D",
|
||||
[
|
||||
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(__file__.split("/")[:-1] + ["chamfer3D.cu"]),
|
||||
],
|
||||
),
|
||||
],
|
||||
|
||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
|
|
|
@ -1,24 +1,29 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
|
||||
chamfer_found = importlib.find_loader("chamfer_5D") is not None
|
||||
if not chamfer_found:
|
||||
## Cool trick from https://github.com/chrdiller
|
||||
print("Jitting Chamfer 5D")
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_5D = load(name="chamfer_5D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
|
||||
])
|
||||
|
||||
chamfer_5D = load(
|
||||
name="chamfer_5D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer5D.cu"]),
|
||||
],
|
||||
)
|
||||
print("Loaded JIT 5D CUDA chamfer distance")
|
||||
|
||||
else:
|
||||
import chamfer_5D
|
||||
|
||||
print("Loaded compiled 5D CUDA chamfer distance")
|
||||
|
||||
|
||||
|
@ -59,9 +64,7 @@ class chamfer_5DFunction(Function):
|
|||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_5D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
chamfer_5D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
|
|
|
@ -2,15 +2,16 @@ from setuptools import setup
|
|||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_5D',
|
||||
name="chamfer_5D",
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_5D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
|
||||
]),
|
||||
CUDAExtension(
|
||||
"chamfer_5D",
|
||||
[
|
||||
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(__file__.split("/")[:-1] + ["chamfer5D.cu"]),
|
||||
],
|
||||
),
|
||||
],
|
||||
|
||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
|
|
|
@ -33,8 +33,7 @@ def distChamfer(a, b):
|
|||
xx = torch.pow(x, 2).sum(2)
|
||||
yy = torch.pow(y, 2).sum(2)
|
||||
zz = torch.bmm(x, y.transpose(2, 1))
|
||||
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
|
||||
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
|
||||
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
|
||||
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
|
||||
P = rx.transpose(2, 1) + ry - 2 * zz
|
||||
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
|
||||
|
||||
def fscore(dist1, dist2, threshold=0.001):
|
||||
"""
|
||||
Calculates the F-score between two point clouds with the corresponding threshold value.
|
||||
|
@ -14,4 +15,3 @@ def fscore(dist1, dist2, threshold=0.001):
|
|||
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
|
||||
fscore[torch.isnan(fscore)] = 0
|
||||
return fscore, precision_1, precision_2
|
||||
|
||||
|
|
|
@ -1,20 +1,23 @@
|
|||
import torch, time
|
||||
import time
|
||||
|
||||
import chamfer2D.dist_chamfer_2D
|
||||
import chamfer3D.dist_chamfer_3D
|
||||
import chamfer5D.dist_chamfer_5D
|
||||
import chamfer_python
|
||||
import torch
|
||||
|
||||
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
|
||||
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
|
||||
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
|
||||
|
||||
from torch.autograd import Variable
|
||||
from fscore import fscore
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def test_chamfer(distChamfer, dim):
|
||||
points1 = torch.rand(4, 100, dim).cuda()
|
||||
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
|
||||
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
|
||||
dist1, dist2, idx1, idx2 = distChamfer(points1, points2)
|
||||
|
||||
loss = torch.sum(dist1)
|
||||
loss.backward()
|
||||
|
@ -29,9 +32,9 @@ def test_chamfer(distChamfer, dim):
|
|||
xd1 = idx1 - myidx1
|
||||
xd2 = idx2 - myidx2
|
||||
assert (
|
||||
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
|
||||
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
|
||||
), "chamfer cuda and chamfer normal are not giving the same results"
|
||||
print(f"fscore :", fscore(dist1, dist2))
|
||||
print("fscore :", fscore(dist1, dist2))
|
||||
print("Unit test passed")
|
||||
|
||||
|
||||
|
@ -49,7 +52,6 @@ def timings(distChamfer, dim):
|
|||
loss.backward()
|
||||
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
||||
|
||||
|
||||
print("Timings : Start Pythonic version")
|
||||
start = time.time()
|
||||
for i in range(num_it):
|
||||
|
@ -61,9 +63,8 @@ def timings(distChamfer, dim):
|
|||
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
||||
|
||||
|
||||
|
||||
dims = [2,3,5]
|
||||
for i,cham in enumerate([cham2D, cham3D, cham5D]):
|
||||
dims = [2, 3, 5]
|
||||
for i, cham in enumerate([cham2D, cham3D, cham5D]):
|
||||
print(f"testing Chamfer {dims[i]}D")
|
||||
test_chamfer(cham, dims[i])
|
||||
timings(cham, dims[i])
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
import emd_cuda
|
||||
import torch
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
|
@ -44,4 +44,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
|
|||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||
cost = cost / xyz1.shape[1]
|
||||
return cost
|
||||
|
||||
|
|
|
@ -9,19 +9,17 @@ Notes:
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
setup(
|
||||
name='emd_ext',
|
||||
name="emd_ext",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='emd_cuda',
|
||||
name="emd_cuda",
|
||||
sources=[
|
||||
'cuda/emd.cpp',
|
||||
'cuda/emd_kernel.cu',
|
||||
"cuda/emd.cpp",
|
||||
"cuda/emd_kernel.cu",
|
||||
],
|
||||
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
||||
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
from emd import earth_mover_distance
|
||||
|
||||
# gt
|
||||
|
@ -13,10 +12,12 @@ print(p2)
|
|||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
||||
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
||||
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
||||
print('gt_dist: ', gt_dist)
|
||||
gt_dist = (
|
||||
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
|
||||
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
|
||||
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
|
||||
)
|
||||
print("gt_dist: ", gt_dist)
|
||||
|
||||
gt_dist.backward()
|
||||
print(p1.grad)
|
||||
|
@ -41,4 +42,3 @@ print(loss)
|
|||
loss.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
||||
|
|
|
@ -1,17 +1,19 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.linalg import norm
|
||||
from scipy.stats import entropy
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from numpy.linalg import norm
|
||||
|
||||
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
|
||||
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
|
||||
from metrics.ChamferDistancePytorch.fscore import fscore
|
||||
from tqdm import tqdm
|
||||
|
||||
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
|
||||
from metrics.ChamferDistancePytorch.fscore import fscore
|
||||
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
|
||||
|
||||
cham3D = chamfer_3DDist()
|
||||
|
||||
|
||||
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||
def distChamfer(a, b):
|
||||
x, y = a, b
|
||||
|
@ -22,11 +24,11 @@ def distChamfer(a, b):
|
|||
diag_ind = torch.arange(0, num_points).to(a).long()
|
||||
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
|
||||
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
|
||||
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
||||
P = rx.transpose(2, 1) + ry - 2 * zz
|
||||
return P.min(1)[0], P.min(2)[0]
|
||||
|
||||
|
||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
|
||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
|
||||
N_sample = sample_pcs.shape[0]
|
||||
N_ref = ref_pcs.shape[0]
|
||||
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
|
||||
|
@ -56,13 +58,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
|
|||
cd = torch.cat(cd_lst)
|
||||
emd = torch.cat(emd_lst)
|
||||
fs_lst = torch.cat(fs_lst).mean()
|
||||
results = {
|
||||
'MMD-CD': cd,
|
||||
'MMD-EMD': emd,
|
||||
'fscore': fs_lst
|
||||
}
|
||||
results = {"MMD-CD": cd, "MMD-EMD": emd, "fscore": fs_lst}
|
||||
return results
|
||||
|
||||
|
||||
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
|
||||
N_sample = sample_pcs.shape[0]
|
||||
N_ref = ref_pcs.shape[0]
|
||||
|
@ -107,7 +106,7 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
|
|||
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
|
||||
if sqrt:
|
||||
M = M.abs().sqrt()
|
||||
INFINITY = float('inf')
|
||||
INFINITY = float("inf")
|
||||
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
|
||||
|
||||
count = torch.zeros(n0 + n1).to(Mxx)
|
||||
|
@ -116,19 +115,21 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
|
|||
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
|
||||
|
||||
s = {
|
||||
'tp': (pred * label).sum(),
|
||||
'fp': (pred * (1 - label)).sum(),
|
||||
'fn': ((1 - pred) * label).sum(),
|
||||
'tn': ((1 - pred) * (1 - label)).sum(),
|
||||
"tp": (pred * label).sum(),
|
||||
"fp": (pred * (1 - label)).sum(),
|
||||
"fn": ((1 - pred) * label).sum(),
|
||||
"tn": ((1 - pred) * (1 - label)).sum(),
|
||||
}
|
||||
|
||||
s.update({
|
||||
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
|
||||
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
|
||||
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
|
||||
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
|
||||
'acc': torch.eq(label, pred).float().mean(),
|
||||
})
|
||||
s.update(
|
||||
{
|
||||
"precision": s["tp"] / (s["tp"] + s["fp"] + 1e-10),
|
||||
"recall": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
|
||||
"acc_t": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
|
||||
"acc_f": s["tn"] / (s["tn"] + s["fp"] + 1e-10),
|
||||
"acc": torch.eq(label, pred).float().mean(),
|
||||
}
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
|
@ -141,9 +142,9 @@ def lgan_mmd_cov(all_dist):
|
|||
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
|
||||
cov = torch.tensor(cov).to(all_dist)
|
||||
return {
|
||||
'lgan_mmd': mmd,
|
||||
'lgan_cov': cov,
|
||||
'lgan_mmd_smp': mmd_smp,
|
||||
"lgan_mmd": mmd,
|
||||
"lgan_cov": cov,
|
||||
"lgan_mmd_smp": mmd_smp,
|
||||
}
|
||||
|
||||
|
||||
|
@ -153,27 +154,19 @@ def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
|
|||
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
|
||||
|
||||
res_cd = lgan_mmd_cov(M_rs_cd.t())
|
||||
results.update({
|
||||
"%s-CD" % k: v for k, v in res_cd.items()
|
||||
})
|
||||
results.update({"%s-CD" % k: v for k, v in res_cd.items()})
|
||||
|
||||
res_emd = lgan_mmd_cov(M_rs_emd.t())
|
||||
results.update({
|
||||
"%s-EMD" % k: v for k, v in res_emd.items()
|
||||
})
|
||||
results.update({"%s-EMD" % k: v for k, v in res_emd.items()})
|
||||
|
||||
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
|
||||
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
|
||||
|
||||
# 1-NN results
|
||||
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
|
||||
results.update({
|
||||
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
|
||||
})
|
||||
results.update({"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if "acc" in k})
|
||||
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
|
||||
results.update({
|
||||
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
|
||||
})
|
||||
results.update({"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if "acc" in k})
|
||||
|
||||
return results
|
||||
|
||||
|
@ -227,11 +220,11 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
|
|||
bound = 0.5 + epsilon
|
||||
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
||||
if verbose:
|
||||
warnings.warn('Point-clouds are not in unit cube.')
|
||||
warnings.warn("Point-clouds are not in unit cube.")
|
||||
|
||||
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
|
||||
if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
|
||||
if verbose:
|
||||
warnings.warn('Point-clouds are not in unit sphere.')
|
||||
warnings.warn("Point-clouds are not in unit sphere.")
|
||||
|
||||
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
||||
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
||||
|
@ -260,9 +253,9 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
|
|||
|
||||
def jensen_shannon_divergence(P, Q):
|
||||
if np.any(P < 0) or np.any(Q < 0):
|
||||
raise ValueError('Negative values.')
|
||||
raise ValueError("Negative values.")
|
||||
if len(P) != len(Q):
|
||||
raise ValueError('Non equal size.')
|
||||
raise ValueError("Non equal size.")
|
||||
|
||||
P_ = P / np.sum(P) # Ensure probabilities.
|
||||
Q_ = Q / np.sum(Q)
|
||||
|
@ -275,7 +268,7 @@ def jensen_shannon_divergence(P, Q):
|
|||
res2 = _jsdiv(P_, Q_)
|
||||
|
||||
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
||||
warnings.warn('Numerical values of two JSD methods don\'t agree.')
|
||||
warnings.warn("Numerical values of two JSD methods don't agree.")
|
||||
|
||||
return res
|
||||
|
||||
|
@ -312,11 +305,9 @@ if __name__ == "__main__":
|
|||
r_dist = min_r.mean().cpu().detach().item()
|
||||
print(l_dist, r_dist)
|
||||
|
||||
|
||||
emd_batch = EMD(x.cuda(), y.cuda(), False)
|
||||
print(emd_batch.shape)
|
||||
print(emd_batch.mean().detach().item())
|
||||
|
||||
jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy())
|
||||
print(jsd)
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import functools
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
|
||||
|
||||
|
||||
def _linear_gn_relu(in_channels, out_channels):
|
||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
|
||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
|
||||
|
||||
|
||||
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
||||
|
@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
|
|||
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
||||
|
||||
|
||||
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def create_pointnet_components(
|
||||
blocks,
|
||||
in_channels,
|
||||
embed_dim,
|
||||
with_se=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
layers, concat_channels = [], 0
|
||||
|
@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
|
|||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
with_se=with_se, normalize=normalize, eps=eps)
|
||||
block = functools.partial(
|
||||
PVConv,
|
||||
kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution),
|
||||
attention=attention,
|
||||
with_se=with_se,
|
||||
normalize=normalize,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
if c == 0:
|
||||
layers.append(block(in_channels, out_channels))
|
||||
else:
|
||||
layers.append(block(in_channels+embed_dim, out_channels))
|
||||
layers.append(block(in_channels + embed_dim, out_channels))
|
||||
in_channels = out_channels
|
||||
concat_channels += out_channels
|
||||
c += 1
|
||||
return layers, in_channels, concat_channels
|
||||
|
||||
|
||||
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
|
||||
dropout=0.1, with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def create_pointnet2_sa_components(
|
||||
sa_blocks,
|
||||
extra_feature_channels,
|
||||
embed_dim=64,
|
||||
use_att=False,
|
||||
dropout=0.1,
|
||||
with_se=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
in_channels = extra_feature_channels + 3
|
||||
|
||||
|
@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
|||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0
|
||||
attention = (c + 1) % 2 == 0 and c > 0 and use_att and p == 0
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se and not attention, with_se_relu=True,
|
||||
normalize=normalize, eps=eps)
|
||||
block = functools.partial(
|
||||
PVConv,
|
||||
kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution),
|
||||
attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se and not attention,
|
||||
with_se_relu=True,
|
||||
normalize=normalize,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
if c == 0:
|
||||
sa_blocks.append(block(in_channels, out_channels))
|
||||
elif k ==0:
|
||||
sa_blocks.append(block(in_channels+embed_dim, out_channels))
|
||||
elif k == 0:
|
||||
sa_blocks.append(block(in_channels + embed_dim, out_channels))
|
||||
in_channels = out_channels
|
||||
k += 1
|
||||
extra_feature_channels = in_channels
|
||||
|
@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
|||
if num_centers is None:
|
||||
block = PointNetAModule
|
||||
else:
|
||||
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
|
||||
num_neighbors=num_neighbors)
|
||||
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
|
||||
include_coordinates=True))
|
||||
block = functools.partial(
|
||||
PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
|
||||
)
|
||||
sa_blocks.append(
|
||||
block(
|
||||
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
|
||||
out_channels=out_channels,
|
||||
include_coordinates=True,
|
||||
)
|
||||
)
|
||||
c += 1
|
||||
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
||||
if len(sa_blocks) == 1:
|
||||
|
@ -127,10 +165,20 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
|||
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
||||
|
||||
|
||||
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False,
|
||||
dropout=0.1,
|
||||
with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def create_pointnet2_fp_modules(
|
||||
fp_blocks,
|
||||
in_channels,
|
||||
sa_in_channels,
|
||||
sv_points,
|
||||
embed_dim=64,
|
||||
use_att=False,
|
||||
dropout=0.1,
|
||||
with_se=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
fp_layers = []
|
||||
|
@ -139,7 +187,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
|
|||
fp_blocks = []
|
||||
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
||||
fp_blocks.append(
|
||||
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
|
||||
PointNetFPModule(
|
||||
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
|
||||
)
|
||||
)
|
||||
in_channels = out_channels[-1]
|
||||
|
||||
|
@ -151,9 +201,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
|
|||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps)
|
||||
block = functools.partial(
|
||||
PVConv,
|
||||
kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution),
|
||||
attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se and not attention,
|
||||
with_se_relu=True,
|
||||
normalize=normalize,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
fp_blocks.append(block(in_channels, out_channels))
|
||||
in_channels = out_channels
|
||||
|
@ -168,9 +226,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
|
|||
|
||||
|
||||
class PVCNN2Base(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1,
|
||||
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
sv_points,
|
||||
embed_dim,
|
||||
use_att,
|
||||
dropout=0.1,
|
||||
extra_feature_channels=3,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
super().__init__()
|
||||
assert extra_feature_channels >= 0
|
||||
self.embed_dim = embed_dim
|
||||
|
@ -178,9 +244,14 @@ class PVCNN2Base(nn.Module):
|
|||
self.in_channels = extra_feature_channels + 3
|
||||
|
||||
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
|
||||
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
|
||||
use_att=use_att, dropout=dropout,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
||||
sa_blocks=self.sa_blocks,
|
||||
extra_feature_channels=extra_feature_channels,
|
||||
with_se=True,
|
||||
embed_dim=embed_dim,
|
||||
use_att=use_att,
|
||||
dropout=dropout,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
)
|
||||
self.sa_layers = nn.ModuleList(sa_layers)
|
||||
|
||||
|
@ -189,16 +260,26 @@ class PVCNN2Base(nn.Module):
|
|||
# only use extra features in the last fp module
|
||||
sa_in_channels[0] = extra_feature_channels
|
||||
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
||||
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points,
|
||||
with_se=True, embed_dim=embed_dim,
|
||||
use_att=use_att, dropout=dropout,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
||||
fp_blocks=self.fp_blocks,
|
||||
in_channels=channels_sa_features,
|
||||
sa_in_channels=sa_in_channels,
|
||||
sv_points=sv_points,
|
||||
with_se=True,
|
||||
embed_dim=embed_dim,
|
||||
use_att=use_att,
|
||||
dropout=dropout,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
)
|
||||
self.fp_layers = nn.ModuleList(fp_layers)
|
||||
|
||||
|
||||
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes],
|
||||
classifier=True, dim=2, width_multiplier=width_multiplier)
|
||||
layers, _ = create_mlp_components(
|
||||
in_channels=channels_fp_features,
|
||||
out_channels=[128, 0.5, num_classes],
|
||||
classifier=True,
|
||||
dim=2,
|
||||
width_multiplier=width_multiplier,
|
||||
)
|
||||
self.classifier = nn.Sequential(*layers)
|
||||
|
||||
self.embedf = nn.Sequential(
|
||||
|
@ -223,31 +304,30 @@ class PVCNN2Base(nn.Module):
|
|||
return emb
|
||||
|
||||
def forward(self, inputs, t):
|
||||
|
||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
|
||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
|
||||
|
||||
# inputs : [B, in_channels + S, N]
|
||||
coords, features = inputs[:, :3, :].contiguous(), inputs
|
||||
coords_list, in_features_list = [], []
|
||||
for i, sa_blocks in enumerate(self.sa_layers):
|
||||
for i, sa_blocks in enumerate(self.sa_layers):
|
||||
in_features_list.append(features)
|
||||
coords_list.append(coords)
|
||||
if i == 0:
|
||||
features, coords, temb = sa_blocks ((features, coords, temb))
|
||||
features, coords, temb = sa_blocks((features, coords, temb))
|
||||
else:
|
||||
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
|
||||
features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
|
||||
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
||||
if self.global_att is not None:
|
||||
features = self.global_att(features)
|
||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||
jump_coords = coords_list[-1 - fp_idx]
|
||||
fump_feats = in_features_list[-1-fp_idx]
|
||||
fump_feats = in_features_list[-1 - fp_idx]
|
||||
# if fp_idx == len(self.fp_layers) - 1:
|
||||
# jump_coords = jump_coords[:,:,self.sv_points:]
|
||||
# fump_feats = fump_feats[:,:,self.sv_points:]
|
||||
|
||||
features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features,temb],dim=1), fump_feats, temb))
|
||||
features, coords, temb = fp_blocks(
|
||||
(jump_coords, coords, torch.cat([features, temb], dim=1), fump_feats, temb)
|
||||
)
|
||||
|
||||
return self.classifier(features)
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import functools
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
|
||||
|
||||
|
||||
def _linear_gn_relu(in_channels, out_channels):
|
||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
|
||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
|
||||
|
||||
|
||||
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
||||
|
@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
|
|||
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
||||
|
||||
|
||||
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def create_pointnet_components(
|
||||
blocks,
|
||||
in_channels,
|
||||
embed_dim,
|
||||
with_se=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
layers, concat_channels = [], 0
|
||||
|
@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
|
|||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
with_se=with_se, normalize=normalize, eps=eps)
|
||||
block = functools.partial(
|
||||
PVConv,
|
||||
kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution),
|
||||
attention=attention,
|
||||
with_se=with_se,
|
||||
normalize=normalize,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
if c == 0:
|
||||
layers.append(block(in_channels, out_channels))
|
||||
else:
|
||||
layers.append(block(in_channels+embed_dim, out_channels))
|
||||
layers.append(block(in_channels + embed_dim, out_channels))
|
||||
in_channels = out_channels
|
||||
concat_channels += out_channels
|
||||
c += 1
|
||||
return layers, in_channels, concat_channels
|
||||
|
||||
|
||||
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
|
||||
dropout=0.1, with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def create_pointnet2_sa_components(
|
||||
sa_blocks,
|
||||
extra_feature_channels,
|
||||
embed_dim=64,
|
||||
use_att=False,
|
||||
dropout=0.1,
|
||||
with_se=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
in_channels = extra_feature_channels + 3
|
||||
|
||||
|
@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
|||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = (c+1) % 2 == 0 and use_att and p == 0
|
||||
attention = (c + 1) % 2 == 0 and use_att and p == 0
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se, with_se_relu=True,
|
||||
normalize=normalize, eps=eps)
|
||||
block = functools.partial(
|
||||
PVConv,
|
||||
kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution),
|
||||
attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se,
|
||||
with_se_relu=True,
|
||||
normalize=normalize,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
if c == 0:
|
||||
sa_blocks.append(block(in_channels, out_channels))
|
||||
elif k ==0:
|
||||
sa_blocks.append(block(in_channels+embed_dim, out_channels))
|
||||
elif k == 0:
|
||||
sa_blocks.append(block(in_channels + embed_dim, out_channels))
|
||||
in_channels = out_channels
|
||||
k += 1
|
||||
extra_feature_channels = in_channels
|
||||
|
@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
|||
if num_centers is None:
|
||||
block = PointNetAModule
|
||||
else:
|
||||
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
|
||||
num_neighbors=num_neighbors)
|
||||
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
|
||||
include_coordinates=True))
|
||||
block = functools.partial(
|
||||
PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
|
||||
)
|
||||
sa_blocks.append(
|
||||
block(
|
||||
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
|
||||
out_channels=out_channels,
|
||||
include_coordinates=True,
|
||||
)
|
||||
)
|
||||
c += 1
|
||||
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
||||
if len(sa_blocks) == 1:
|
||||
|
@ -127,10 +165,19 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
|||
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
||||
|
||||
|
||||
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
|
||||
dropout=0.1,
|
||||
with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def create_pointnet2_fp_modules(
|
||||
fp_blocks,
|
||||
in_channels,
|
||||
sa_in_channels,
|
||||
embed_dim=64,
|
||||
use_att=False,
|
||||
dropout=0.1,
|
||||
with_se=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
fp_layers = []
|
||||
|
@ -139,7 +186,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
|
|||
fp_blocks = []
|
||||
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
||||
fp_blocks.append(
|
||||
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
|
||||
PointNetFPModule(
|
||||
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
|
||||
)
|
||||
)
|
||||
in_channels = out_channels[-1]
|
||||
|
||||
|
@ -147,14 +196,21 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
|
|||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
|
||||
attention = (c + 1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se, with_se_relu=True,
|
||||
normalize=normalize, eps=eps)
|
||||
block = functools.partial(
|
||||
PVConv,
|
||||
kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution),
|
||||
attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se,
|
||||
with_se_relu=True,
|
||||
normalize=normalize,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
fp_blocks.append(block(in_channels, out_channels))
|
||||
in_channels = out_channels
|
||||
|
@ -168,20 +224,31 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
|
|||
return fp_layers, in_channels
|
||||
|
||||
|
||||
|
||||
class PVCNN2Base(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, embed_dim, use_att, dropout=0.1,
|
||||
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
embed_dim,
|
||||
use_att,
|
||||
dropout=0.1,
|
||||
extra_feature_channels=3,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
super().__init__()
|
||||
assert extra_feature_channels >= 0
|
||||
self.embed_dim = embed_dim
|
||||
self.in_channels = extra_feature_channels + 3
|
||||
|
||||
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
|
||||
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
|
||||
use_att=use_att, dropout=dropout,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
||||
sa_blocks=self.sa_blocks,
|
||||
extra_feature_channels=extra_feature_channels,
|
||||
with_se=True,
|
||||
embed_dim=embed_dim,
|
||||
use_att=use_att,
|
||||
dropout=dropout,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
)
|
||||
self.sa_layers = nn.ModuleList(sa_layers)
|
||||
|
||||
|
@ -190,15 +257,25 @@ class PVCNN2Base(nn.Module):
|
|||
# only use extra features in the last fp module
|
||||
sa_in_channels[0] = extra_feature_channels
|
||||
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
||||
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim,
|
||||
use_att=use_att, dropout=dropout,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
||||
fp_blocks=self.fp_blocks,
|
||||
in_channels=channels_sa_features,
|
||||
sa_in_channels=sa_in_channels,
|
||||
with_se=True,
|
||||
embed_dim=embed_dim,
|
||||
use_att=use_att,
|
||||
dropout=dropout,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
)
|
||||
self.fp_layers = nn.ModuleList(fp_layers)
|
||||
|
||||
|
||||
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes], # was 0.5
|
||||
classifier=True, dim=2, width_multiplier=width_multiplier)
|
||||
layers, _ = create_mlp_components(
|
||||
in_channels=channels_fp_features,
|
||||
out_channels=[128, dropout, num_classes], # was 0.5
|
||||
classifier=True,
|
||||
dim=2,
|
||||
width_multiplier=width_multiplier,
|
||||
)
|
||||
self.classifier = nn.Sequential(*layers)
|
||||
|
||||
self.embedf = nn.Sequential(
|
||||
|
@ -223,25 +300,30 @@ class PVCNN2Base(nn.Module):
|
|||
return emb
|
||||
|
||||
def forward(self, inputs, t):
|
||||
|
||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
|
||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
|
||||
|
||||
# inputs : [B, in_channels + S, N]
|
||||
coords, features = inputs[:, :3, :].contiguous(), inputs
|
||||
coords_list, in_features_list = [], []
|
||||
for i, sa_blocks in enumerate(self.sa_layers):
|
||||
for i, sa_blocks in enumerate(self.sa_layers):
|
||||
in_features_list.append(features)
|
||||
coords_list.append(coords)
|
||||
if i == 0:
|
||||
features, coords, temb = sa_blocks ((features, coords, temb))
|
||||
features, coords, temb = sa_blocks((features, coords, temb))
|
||||
else:
|
||||
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
|
||||
features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
|
||||
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
||||
if self.global_att is not None:
|
||||
features = self.global_att(features)
|
||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||
features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb))
|
||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||
features, coords, temb = fp_blocks(
|
||||
(
|
||||
coords_list[-1 - fp_idx],
|
||||
coords,
|
||||
torch.cat([features, temb], dim=1),
|
||||
in_features_list[-1 - fp_idx],
|
||||
temb,
|
||||
)
|
||||
)
|
||||
|
||||
return self.classifier(features)
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
from modules.ball_query import BallQuery
|
||||
from modules.frustum import FrustumPointNetLoss
|
||||
from modules.loss import KLLoss
|
||||
from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
|
||||
from modules.pvconv import PVConv, Attention, Swish, PVConvReLU
|
||||
from modules.se import SE3d
|
||||
from modules.shared_mlp import SharedMLP
|
||||
from modules.voxelization import Voxelization
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
|
||||
import modules.functional as F
|
||||
|
||||
__all__ = ['BallQuery']
|
||||
__all__ = ["BallQuery"]
|
||||
|
||||
|
||||
class BallQuery(nn.Module):
|
||||
|
@ -21,7 +21,7 @@ class BallQuery(nn.Module):
|
|||
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
|
||||
|
||||
if points_features is None:
|
||||
assert self.include_coordinates, 'No Features For Grouping'
|
||||
assert self.include_coordinates, "No Features For Grouping"
|
||||
neighbor_features = neighbor_coordinates
|
||||
else:
|
||||
neighbor_features = F.grouping(points_features, neighbor_indices)
|
||||
|
@ -30,5 +30,6 @@ class BallQuery(nn.Module):
|
|||
return neighbor_features, F.grouping(temb, neighbor_indices)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'radius={}, num_neighbors={}{}'.format(
|
||||
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
|
||||
return "radius={}, num_neighbors={}{}".format(
|
||||
self.radius, self.num_neighbors, ", include coordinates" if self.include_coordinates else ""
|
||||
)
|
||||
|
|
|
@ -5,12 +5,20 @@ import torch.nn.functional as F
|
|||
|
||||
import modules.functional as PF
|
||||
|
||||
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
|
||||
__all__ = ["FrustumPointNetLoss", "get_box_corners_3d"]
|
||||
|
||||
|
||||
class FrustumPointNetLoss(nn.Module):
|
||||
def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
|
||||
corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
|
||||
def __init__(
|
||||
self,
|
||||
num_heading_angle_bins,
|
||||
num_size_templates,
|
||||
size_templates,
|
||||
box_loss_weight=1.0,
|
||||
corners_loss_weight=10.0,
|
||||
heading_residual_loss_weight=20.0,
|
||||
size_residual_loss_weight=20.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.box_loss_weight = box_loss_weight
|
||||
self.corners_loss_weight = corners_loss_weight
|
||||
|
@ -19,28 +27,28 @@ class FrustumPointNetLoss(nn.Module):
|
|||
|
||||
self.num_heading_angle_bins = num_heading_angle_bins
|
||||
self.num_size_templates = num_size_templates
|
||||
self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
|
||||
self.register_buffer("size_templates", size_templates.view(self.num_size_templates, 3))
|
||||
self.register_buffer(
|
||||
'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
|
||||
"heading_angle_bin_centers", torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
|
||||
)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
mask_logits = inputs['mask_logits'] # (B, 2, N)
|
||||
center_reg = inputs['center_reg'] # (B, 3)
|
||||
center = inputs['center'] # (B, 3)
|
||||
heading_scores = inputs['heading_scores'] # (B, NH)
|
||||
heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
|
||||
heading_residuals = inputs['heading_residuals'] # (B, NH)
|
||||
size_scores = inputs['size_scores'] # (B, NS)
|
||||
size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
|
||||
size_residuals = inputs['size_residuals'] # (B, NS, 3)
|
||||
mask_logits = inputs["mask_logits"] # (B, 2, N)
|
||||
center_reg = inputs["center_reg"] # (B, 3)
|
||||
center = inputs["center"] # (B, 3)
|
||||
heading_scores = inputs["heading_scores"] # (B, NH)
|
||||
heading_residuals_normalized = inputs["heading_residuals_normalized"] # (B, NH)
|
||||
heading_residuals = inputs["heading_residuals"] # (B, NH)
|
||||
size_scores = inputs["size_scores"] # (B, NS)
|
||||
size_residuals_normalized = inputs["size_residuals_normalized"] # (B, NS, 3)
|
||||
size_residuals = inputs["size_residuals"] # (B, NS, 3)
|
||||
|
||||
mask_logits_target = targets['mask_logits'] # (B, N)
|
||||
center_target = targets['center'] # (B, 3)
|
||||
heading_bin_id_target = targets['heading_bin_id'] # (B, )
|
||||
heading_residual_target = targets['heading_residual'] # (B, )
|
||||
size_template_id_target = targets['size_template_id'] # (B, )
|
||||
size_residual_target = targets['size_residual'] # (B, 3)
|
||||
mask_logits_target = targets["mask_logits"] # (B, N)
|
||||
center_target = targets["center"] # (B, 3)
|
||||
heading_bin_id_target = targets["heading_bin_id"] # (B, )
|
||||
heading_residual_target = targets["heading_residual"] # (B, )
|
||||
size_template_id_target = targets["size_template_id"] # (B, )
|
||||
size_residual_target = targets["size_residual"] # (B, 3)
|
||||
|
||||
batch_size = center.size(0)
|
||||
batch_id = torch.arange(batch_size, device=center.device)
|
||||
|
@ -65,25 +73,32 @@ class FrustumPointNetLoss(nn.Module):
|
|||
)
|
||||
|
||||
# Bounding box losses
|
||||
heading = (heading_residuals[batch_id, heading_bin_id_target]
|
||||
+ self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
|
||||
heading = (
|
||||
heading_residuals[batch_id, heading_bin_id_target] + self.heading_angle_bin_centers[heading_bin_id_target]
|
||||
) # (B, )
|
||||
# Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
|
||||
size = (size_residuals[batch_id, size_template_id_target]
|
||||
+ self.size_templates[size_template_id_target]) # (B, 3)
|
||||
size = (
|
||||
size_residuals[batch_id, size_template_id_target] + self.size_templates[size_template_id_target]
|
||||
) # (B, 3)
|
||||
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
|
||||
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
|
||||
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
|
||||
corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
|
||||
sizes=size_target, with_flip=True) # (B, 3, 8)
|
||||
corners_loss = PF.huber_loss(torch.min(
|
||||
torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
|
||||
), delta=1.0)
|
||||
corners_target, corners_target_flip = get_box_corners_3d(
|
||||
centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
|
||||
) # (B, 3, 8)
|
||||
corners_loss = PF.huber_loss(
|
||||
torch.min(torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)),
|
||||
delta=1.0,
|
||||
)
|
||||
# Summing up
|
||||
loss = mask_loss + self.box_loss_weight * (
|
||||
center_loss + center_reg_loss + heading_loss + size_loss
|
||||
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
|
||||
+ self.size_residual_loss_weight * size_residual_normalized_loss
|
||||
+ self.corners_loss_weight * corners_loss
|
||||
center_loss
|
||||
+ center_reg_loss
|
||||
+ heading_loss
|
||||
+ size_loss
|
||||
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
|
||||
+ self.size_residual_loss_weight * size_residual_normalized_loss
|
||||
+ self.corners_loss_weight * corners_loss
|
||||
)
|
||||
|
||||
return loss
|
||||
|
@ -105,9 +120,9 @@ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
|
|||
l = sizes[:, 0] # (N,)
|
||||
w = sizes[:, 1] # (N,)
|
||||
h = sizes[:, 2] # (N,)
|
||||
x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8)
|
||||
y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8)
|
||||
z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8)
|
||||
x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
|
||||
y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
|
||||
z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
|
||||
|
||||
c = torch.cos(headings) # (N,)
|
||||
s = torch.sin(headings) # (N,)
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
from modules.functional.ball_query import ball_query
|
||||
from modules.functional.devoxelization import trilinear_devoxelize
|
||||
from modules.functional.grouping import grouping
|
||||
from modules.functional.interpolatation import nearest_neighbor_interpolate
|
||||
from modules.functional.loss import kl_loss, huber_loss
|
||||
from modules.functional.sampling import gather, furthest_point_sample, logits_mask
|
||||
from modules.functional.voxelization import avg_voxelize
|
|
@ -3,24 +3,28 @@ import os
|
|||
from torch.utils.cpp_extension import load
|
||||
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
_backend = load(name='_pvcnn_backend',
|
||||
extra_cflags=['-O3', '-std=c++17'],
|
||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
||||
sources=[os.path.join(_src_path,'src', f) for f in [
|
||||
'ball_query/ball_query.cpp',
|
||||
'ball_query/ball_query.cu',
|
||||
'grouping/grouping.cpp',
|
||||
'grouping/grouping.cu',
|
||||
'interpolate/neighbor_interpolate.cpp',
|
||||
'interpolate/neighbor_interpolate.cu',
|
||||
'interpolate/trilinear_devox.cpp',
|
||||
'interpolate/trilinear_devox.cu',
|
||||
'sampling/sampling.cpp',
|
||||
'sampling/sampling.cu',
|
||||
'voxelization/vox.cpp',
|
||||
'voxelization/vox.cu',
|
||||
'bindings.cpp',
|
||||
]]
|
||||
)
|
||||
_backend = load(
|
||||
name="_pvcnn_backend",
|
||||
extra_cflags=["-O3", "-std=c++17"],
|
||||
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||
sources=[
|
||||
os.path.join(_src_path, "src", f)
|
||||
for f in [
|
||||
"ball_query/ball_query.cpp",
|
||||
"ball_query/ball_query.cu",
|
||||
"grouping/grouping.cpp",
|
||||
"grouping/grouping.cu",
|
||||
"interpolate/neighbor_interpolate.cpp",
|
||||
"interpolate/neighbor_interpolate.cu",
|
||||
"interpolate/trilinear_devox.cpp",
|
||||
"interpolate/trilinear_devox.cu",
|
||||
"sampling/sampling.cpp",
|
||||
"sampling/sampling.cu",
|
||||
"voxelization/vox.cpp",
|
||||
"voxelization/vox.cu",
|
||||
"bindings.cpp",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
__all__ = ['_backend']
|
||||
__all__ = ["_backend"]
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
from torch.autograd import Function
|
||||
|
||||
from modules.functional.backend import _backend
|
||||
|
||||
__all__ = ['ball_query']
|
||||
__all__ = ["ball_query"]
|
||||
|
||||
|
||||
def ball_query(centers_coords, points_coords, radius, num_neighbors):
|
||||
"""
|
||||
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
||||
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
||||
:param radius: float, radius of ball query
|
||||
:param num_neighbors: int, maximum number of neighbors
|
||||
:return:
|
||||
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
|
||||
"""
|
||||
centers_coords = centers_coords.contiguous()
|
||||
points_coords = points_coords.contiguous()
|
||||
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
|
||||
"""
|
||||
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
||||
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
||||
:param radius: float, radius of ball query
|
||||
:param num_neighbors: int, maximum number of neighbors
|
||||
:return:
|
||||
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
|
||||
"""
|
||||
centers_coords = centers_coords.contiguous()
|
||||
points_coords = points_coords.contiguous()
|
||||
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
|
||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
|||
|
||||
from modules.functional.backend import _backend
|
||||
|
||||
__all__ = ['trilinear_devoxelize']
|
||||
__all__ = ["trilinear_devoxelize"]
|
||||
|
||||
|
||||
class TrilinearDevoxelization(Function):
|
||||
|
@ -29,7 +29,7 @@ class TrilinearDevoxelization(Function):
|
|||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
:param ctx:
|
||||
:param ctx:
|
||||
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
|
||||
:return:
|
||||
gradient of inputs, FloatTensor[B, C, R, R, R]
|
||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
|||
|
||||
from modules.functional.backend import _backend
|
||||
|
||||
__all__ = ['grouping']
|
||||
__all__ = ["grouping"]
|
||||
|
||||
|
||||
class Grouping(Function):
|
||||
|
@ -23,7 +23,7 @@ class Grouping(Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
(indices,) = ctx.saved_tensors
|
||||
grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
|
||||
return grad_features, None
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
|||
|
||||
from modules.functional.backend import _backend
|
||||
|
||||
__all__ = ['nearest_neighbor_interpolate']
|
||||
__all__ = ["nearest_neighbor_interpolate"]
|
||||
|
||||
|
||||
class NeighborInterpolation(Function):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['kl_loss', 'huber_loss']
|
||||
__all__ = ["kl_loss", "huber_loss"]
|
||||
|
||||
|
||||
def kl_loss(x, y):
|
||||
|
@ -13,5 +13,5 @@ def kl_loss(x, y):
|
|||
def huber_loss(error, delta):
|
||||
abs_error = torch.abs(error)
|
||||
quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
|
||||
losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
|
||||
losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
|
||||
return torch.mean(losses)
|
||||
|
|
|
@ -4,7 +4,7 @@ from torch.autograd import Function
|
|||
|
||||
from modules.functional.backend import _backend
|
||||
|
||||
__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
|
||||
__all__ = ["gather", "furthest_point_sample", "logits_mask"]
|
||||
|
||||
|
||||
class Gather(Function):
|
||||
|
@ -26,7 +26,7 @@ class Gather(Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
(indices,) = ctx.saved_tensors
|
||||
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
|
||||
return grad_features, None
|
||||
|
||||
|
@ -60,11 +60,12 @@ def logits_mask(coords, logits, num_points_per_object):
|
|||
mask: mask to select points, BoolTensor[B, N]
|
||||
"""
|
||||
batch_size, _, num_points = coords.shape
|
||||
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
|
||||
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
|
||||
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
|
||||
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
|
||||
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
|
||||
torch.ones_like(num_candidates)).float() # [B, C]
|
||||
masked_coords_mean = (
|
||||
torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, torch.ones_like(num_candidates)).float()
|
||||
) # [B, C]
|
||||
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
|
||||
for i in range(batch_size):
|
||||
current_mask = mask[i] # [N]
|
||||
|
@ -74,10 +75,14 @@ def logits_mask(coords, logits, num_points_per_object):
|
|||
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
|
||||
selected_indices[i] = current_candidates[choices]
|
||||
elif current_num_candidates > 0:
|
||||
choices = np.concatenate([
|
||||
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
|
||||
np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
|
||||
])
|
||||
choices = np.concatenate(
|
||||
[
|
||||
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
|
||||
np.random.choice(
|
||||
current_num_candidates, num_points_per_object % current_num_candidates, replace=False
|
||||
),
|
||||
]
|
||||
)
|
||||
np.random.shuffle(choices)
|
||||
selected_indices[i] = current_candidates[choices]
|
||||
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
|
||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
|||
|
||||
from modules.functional.backend import _backend
|
||||
|
||||
__all__ = ['avg_voxelize']
|
||||
__all__ = ["avg_voxelize"]
|
||||
|
||||
|
||||
class AvgVoxelization(Function):
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch.nn as nn
|
|||
|
||||
import modules.functional as F
|
||||
|
||||
__all__ = ['KLLoss']
|
||||
__all__ = ["KLLoss"]
|
||||
|
||||
|
||||
class KLLoss(nn.Module):
|
||||
|
|
|
@ -5,7 +5,7 @@ import modules.functional as F
|
|||
from modules.ball_query import BallQuery
|
||||
from modules.shared_mlp import SharedMLP
|
||||
|
||||
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
|
||||
__all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]
|
||||
|
||||
|
||||
class PointNetAModule(nn.Module):
|
||||
|
@ -20,8 +20,9 @@ class PointNetAModule(nn.Module):
|
|||
total_out_channels = 0
|
||||
for _out_channels in out_channels:
|
||||
mlps.append(
|
||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
||||
out_channels=_out_channels, dim=1)
|
||||
SharedMLP(
|
||||
in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=1
|
||||
)
|
||||
)
|
||||
total_out_channels += _out_channels[-1]
|
||||
|
||||
|
@ -43,7 +44,7 @@ class PointNetAModule(nn.Module):
|
|||
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
|
||||
|
||||
def extra_repr(self):
|
||||
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
|
||||
return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"
|
||||
|
||||
|
||||
class PointNetSAModule(nn.Module):
|
||||
|
@ -67,8 +68,9 @@ class PointNetSAModule(nn.Module):
|
|||
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
|
||||
)
|
||||
mlps.append(
|
||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
||||
out_channels=_out_channels, dim=2)
|
||||
SharedMLP(
|
||||
in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=2
|
||||
)
|
||||
)
|
||||
total_out_channels += _out_channels[-1]
|
||||
|
||||
|
@ -90,7 +92,7 @@ class PointNetSAModule(nn.Module):
|
|||
return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
|
||||
|
||||
def extra_repr(self):
|
||||
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
|
||||
return f"num_centers={self.num_centers}, out_channels={self.out_channels}"
|
||||
|
||||
|
||||
class PointNetFPModule(nn.Module):
|
||||
|
@ -107,7 +109,5 @@ class PointNetFPModule(nn.Module):
|
|||
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
|
||||
interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
|
||||
if points_features is not None:
|
||||
interpolated_features = torch.cat(
|
||||
[interpolated_features, points_features], dim=1
|
||||
)
|
||||
interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
|
||||
return self.mlp(interpolated_features), points_coords, interpolated_temb
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
import modules.functional as F
|
||||
from modules.voxelization import Voxelization
|
||||
from modules.shared_mlp import SharedMLP
|
||||
from modules.se import SE3d
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
|
||||
import modules.functional as F
|
||||
from modules.se import SE3d
|
||||
from modules.shared_mlp import SharedMLP
|
||||
from modules.voxelization import Voxelization
|
||||
|
||||
__all__ = ["PVConv", "Attention", "Swish", "PVConvReLU"]
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self,x):
|
||||
return x * torch.sigmoid(x)
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
@ -35,23 +36,19 @@ class Attention(nn.Module):
|
|||
|
||||
self.sm = nn.Softmax(-1)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
B, C = x.shape[:2]
|
||||
h = x
|
||||
|
||||
q = self.q(h).reshape(B, C, -1)
|
||||
k = self.k(h).reshape(B, C, -1)
|
||||
v = self.v(h).reshape(B, C, -1)
|
||||
|
||||
|
||||
|
||||
q = self.q(h).reshape(B,C,-1)
|
||||
k = self.k(h).reshape(B,C,-1)
|
||||
v = self.v(h).reshape(B,C,-1)
|
||||
|
||||
qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
|
||||
qk = torch.matmul(q.permute(0, 2, 1), k) # * (int(C) ** (-0.5))
|
||||
|
||||
w = self.sm(qk)
|
||||
|
||||
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:])
|
||||
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, *x.shape[2:])
|
||||
|
||||
h = self.out(h)
|
||||
|
||||
|
@ -61,9 +58,21 @@ class Attention(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
|
||||
class PVConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
|
||||
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
resolution,
|
||||
attention=False,
|
||||
dropout=0.1,
|
||||
with_se=False,
|
||||
with_se_relu=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
@ -74,13 +83,13 @@ class PVConv(nn.Module):
|
|||
voxel_layers = [
|
||||
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
||||
Swish()
|
||||
Swish(),
|
||||
]
|
||||
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
|
||||
voxel_layers += [
|
||||
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
||||
Attention(out_channels, 8) if attention else Swish()
|
||||
Attention(out_channels, 8) if attention else Swish(),
|
||||
]
|
||||
if with_se:
|
||||
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
|
||||
|
@ -96,10 +105,21 @@ class PVConv(nn.Module):
|
|||
return fused_features, coords, temb
|
||||
|
||||
|
||||
|
||||
class PVConvReLU(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
|
||||
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
resolution,
|
||||
attention=False,
|
||||
leak=0.2,
|
||||
dropout=0.1,
|
||||
with_se=False,
|
||||
with_se_relu=False,
|
||||
normalize=True,
|
||||
eps=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
@ -110,13 +130,13 @@ class PVConvReLU(nn.Module):
|
|||
voxel_layers = [
|
||||
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||
nn.BatchNorm3d(out_channels),
|
||||
nn.LeakyReLU(leak, True)
|
||||
nn.LeakyReLU(leak, True),
|
||||
]
|
||||
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
|
||||
voxel_layers += [
|
||||
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||
nn.BatchNorm3d(out_channels),
|
||||
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
|
||||
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True),
|
||||
]
|
||||
if with_se:
|
||||
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
|
||||
|
|
|
@ -1,18 +1,22 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
__all__ = ['SE3d']
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ["SE3d"]
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self,x):
|
||||
return x * torch.sigmoid(x)
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SE3d(nn.Module):
|
||||
def __init__(self, channel, reduction=8, use_relu=False):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(True) if use_relu else Swish() ,
|
||||
nn.ReLU(True) if use_relu else Swish(),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
nn.Sigmoid()
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['SharedMLP']
|
||||
__all__ = ["SharedMLP"]
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self,x):
|
||||
return x * torch.sigmoid(x)
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SharedMLP(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dim=1):
|
||||
|
@ -23,11 +24,13 @@ class SharedMLP(nn.Module):
|
|||
out_channels = [out_channels]
|
||||
layers = []
|
||||
for oc in out_channels:
|
||||
layers.extend([
|
||||
conv(in_channels, oc, 1),
|
||||
bn(8, oc),
|
||||
Swish(),
|
||||
])
|
||||
layers.extend(
|
||||
[
|
||||
conv(in_channels, oc, 1),
|
||||
bn(8, oc),
|
||||
Swish(),
|
||||
]
|
||||
)
|
||||
in_channels = oc
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
|||
|
||||
import modules.functional as F
|
||||
|
||||
__all__ = ['Voxelization']
|
||||
__all__ = ["Voxelization"]
|
||||
|
||||
|
||||
class Voxelization(nn.Module):
|
||||
|
@ -17,7 +17,10 @@ class Voxelization(nn.Module):
|
|||
coords = coords.detach()
|
||||
norm_coords = coords - coords.mean(2, keepdim=True)
|
||||
if self.normalize:
|
||||
norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
|
||||
norm_coords = (
|
||||
norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps)
|
||||
+ 0.5
|
||||
)
|
||||
else:
|
||||
norm_coords = (norm_coords + 1) / 2.0
|
||||
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
|
||||
|
@ -25,4 +28,4 @@ class Voxelization(nn.Module):
|
|||
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
|
||||
|
||||
def extra_repr(self):
|
||||
return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
|
||||
return "resolution={}{}".format(self.r, ", normalized eps = {}".format(self.eps) if self.normalize else "")
|
||||
|
|
|
@ -1,26 +1,27 @@
|
|||
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
|
||||
import argparse
|
||||
from torch.distributions import Normal
|
||||
from utils.file_utils import *
|
||||
from model.pvcnn_completion import PVCNN2Base
|
||||
|
||||
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
|
||||
from datasets.shapenet_data_sv import *
|
||||
'''
|
||||
from metrics.evaluation_metrics import EMD_CD, compute_all_metrics
|
||||
from model.pvcnn_completion import PVCNN2Base
|
||||
from utils.file_utils import *
|
||||
|
||||
"""
|
||||
models
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
KL divergence between normal distributions parameterized by mean and log-variance.
|
||||
"""
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
||||
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||
# Assumes data is integers [0, 1]
|
||||
|
@ -31,21 +32,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
|||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 0.5)
|
||||
cdf_plus = px0.cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - .5)
|
||||
min_in = inv_stdv * (centered_x - 0.5)
|
||||
cdf_min = px0.cdf(min_in)
|
||||
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
|
||||
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
|
||||
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
|
||||
log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(
|
||||
x < 0.001, log_cdf_plus,
|
||||
torch.where(x > 0.999, log_one_minus_cdf_min,
|
||||
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
|
||||
x < 0.001,
|
||||
log_cdf_plus,
|
||||
torch.where(
|
||||
x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
|
||||
),
|
||||
)
|
||||
assert log_probs.shape == x.shape
|
||||
return log_probs
|
||||
|
||||
|
||||
|
||||
class GaussianDiffusion:
|
||||
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
|
||||
self.loss_type = loss_type
|
||||
|
@ -54,15 +57,15 @@ class GaussianDiffusion:
|
|||
assert isinstance(betas, np.ndarray)
|
||||
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
|
||||
assert (betas > 0).all() and (betas <= 1).all()
|
||||
timesteps, = betas.shape
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.sv_points = sv_points
|
||||
# initialize twice the actual length so we can keep running for eval
|
||||
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
|
||||
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
|
||||
alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
|
||||
|
||||
self.betas = torch.from_numpy(betas).float()
|
||||
self.alphas_cumprod = alphas_cumprod.float()
|
||||
|
@ -70,21 +73,23 @@ class GaussianDiffusion:
|
|||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
|
||||
|
||||
betas = torch.from_numpy(betas).float()
|
||||
alphas = torch.from_numpy(alphas).float()
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
self.posterior_variance = posterior_variance
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
|
||||
)
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
|
||||
@staticmethod
|
||||
def _extract(a, t, x_shape):
|
||||
|
@ -92,17 +97,15 @@ class GaussianDiffusion:
|
|||
Extract some coefficients at specified timesteps,
|
||||
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
|
||||
"""
|
||||
bs, = t.shape
|
||||
(bs,) = t.shape
|
||||
assert x_shape[0] == bs
|
||||
out = torch.gather(a, 0, t)
|
||||
assert out.shape == torch.Size([bs])
|
||||
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
|
||||
|
||||
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||
variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
|
@ -114,54 +117,59 @@ class GaussianDiffusion:
|
|||
noise = torch.randn(x_start.shape, device=x_start.device)
|
||||
assert noise.shape == x_start.shape
|
||||
return (
|
||||
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
||||
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
||||
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||
+ self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
|
||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
|
||||
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
|
||||
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
||||
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
|
||||
+ self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
|
||||
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
|
||||
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
|
||||
x_start.shape[0])
|
||||
posterior_log_variance_clipped = self._extract(
|
||||
self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
|
||||
)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
|
||||
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
|
||||
model_output = denoise_fn(data, t)[:, :, self.sv_points :]
|
||||
|
||||
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
|
||||
|
||||
|
||||
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
|
||||
if self.model_var_type in ["fixedsmall", "fixedlarge"]:
|
||||
# below: only log_variance is used in the KL computations
|
||||
model_variance, model_log_variance = {
|
||||
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
|
||||
'fixedlarge': (self.betas.to(data.device),
|
||||
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
|
||||
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
|
||||
"fixedlarge": (
|
||||
self.betas.to(data.device),
|
||||
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
|
||||
),
|
||||
"fixedsmall": (
|
||||
self.posterior_variance.to(data.device),
|
||||
self.posterior_log_variance_clipped.to(data.device),
|
||||
),
|
||||
}[self.model_var_type]
|
||||
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
|
||||
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
|
||||
else:
|
||||
raise NotImplementedError(self.model_var_type)
|
||||
|
||||
if self.model_mean_type == 'eps':
|
||||
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
|
||||
if self.model_mean_type == "eps":
|
||||
x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output)
|
||||
|
||||
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t)
|
||||
else:
|
||||
raise NotImplementedError(self.loss_type)
|
||||
|
||||
|
||||
assert model_mean.shape == x_recon.shape
|
||||
assert model_variance.shape == model_log_variance.shape
|
||||
if return_pred_xstart:
|
||||
|
@ -172,30 +180,31 @@ class GaussianDiffusion:
|
|||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||
assert x_t.shape == eps.shape
|
||||
return (
|
||||
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
|
||||
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
||||
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
|
||||
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
''' samples '''
|
||||
""" samples """
|
||||
|
||||
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
|
||||
"""
|
||||
Sample from the model
|
||||
"""
|
||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
|
||||
return_pred_xstart=True)
|
||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
||||
denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
|
||||
)
|
||||
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
|
||||
|
||||
# no noise when t == 0
|
||||
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
|
||||
|
||||
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
|
||||
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
|
||||
sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1)
|
||||
return (sample, pred_xstart) if return_pred_xstart else sample
|
||||
|
||||
|
||||
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
|
||||
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
|
||||
def p_sample_loop(
|
||||
self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False
|
||||
):
|
||||
"""
|
||||
Generate samples
|
||||
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
|
||||
|
@ -206,14 +215,21 @@ class GaussianDiffusion:
|
|||
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
|
||||
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
|
||||
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised, return_pred_xstart=False)
|
||||
img_t = self.p_sample(
|
||||
denoise_fn=denoise_fn,
|
||||
data=img_t,
|
||||
t=t_,
|
||||
noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
return_pred_xstart=False,
|
||||
)
|
||||
|
||||
assert img_t[:,:,self.sv_points:].shape == shape
|
||||
assert img_t[:, :, self.sv_points :].shape == shape
|
||||
return img_t
|
||||
|
||||
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
|
||||
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
|
||||
def p_sample_loop_trajectory(
|
||||
self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False
|
||||
):
|
||||
"""
|
||||
Generate samples, returning intermediate images
|
||||
Useful for visualizing how denoised images evolve over time
|
||||
|
@ -223,31 +239,38 @@ class GaussianDiffusion:
|
|||
"""
|
||||
assert isinstance(shape, (tuple, list))
|
||||
|
||||
total_steps = self.num_timesteps if not keep_running else len(self.betas)
|
||||
total_steps = self.num_timesteps if not keep_running else len(self.betas)
|
||||
|
||||
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
|
||||
imgs = [img_t]
|
||||
for t in reversed(range(0,total_steps)):
|
||||
|
||||
for t in reversed(range(0, total_steps)):
|
||||
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
return_pred_xstart=False)
|
||||
if t % freq == 0 or t == total_steps-1:
|
||||
img_t = self.p_sample(
|
||||
denoise_fn=denoise_fn,
|
||||
data=img_t,
|
||||
t=t_,
|
||||
noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
return_pred_xstart=False,
|
||||
)
|
||||
if t % freq == 0 or t == total_steps - 1:
|
||||
imgs.append(img_t)
|
||||
|
||||
assert imgs[-1].shape == shape
|
||||
return imgs
|
||||
|
||||
'''losses'''
|
||||
"""losses"""
|
||||
|
||||
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
||||
x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t
|
||||
)
|
||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
||||
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
|
||||
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
|
||||
)
|
||||
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
|
||||
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
|
||||
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0)
|
||||
|
||||
return (kl, pred_xstart) if return_pred_xstart else kl
|
||||
|
||||
|
@ -259,66 +282,87 @@ class GaussianDiffusion:
|
|||
assert t.shape == torch.Size([B])
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
|
||||
noise = torch.randn(
|
||||
data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device
|
||||
)
|
||||
|
||||
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
|
||||
data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise)
|
||||
|
||||
if self.loss_type == 'mse':
|
||||
if self.loss_type == "mse":
|
||||
# predict the noise instead of x_start. seems to be weighted naturally like SNR
|
||||
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
|
||||
eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[
|
||||
:, :, self.sv_points :
|
||||
]
|
||||
|
||||
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
|
||||
elif self.loss_type == 'kl':
|
||||
losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
|
||||
elif self.loss_type == "kl":
|
||||
losses = self._vb_terms_bpd(
|
||||
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
|
||||
return_pred_xstart=False)
|
||||
denoise_fn=denoise_fn,
|
||||
data_start=data_start,
|
||||
data_t=data_t,
|
||||
t=t,
|
||||
clip_denoised=False,
|
||||
return_pred_xstart=False,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(self.loss_type)
|
||||
|
||||
assert losses.shape == torch.Size([B])
|
||||
return losses
|
||||
|
||||
'''debug'''
|
||||
"""debug"""
|
||||
|
||||
def _prior_bpd(self, x_start):
|
||||
|
||||
with torch.no_grad():
|
||||
B, T = x_start.shape[0], self.num_timesteps
|
||||
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
|
||||
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
|
||||
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
|
||||
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
|
||||
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
|
||||
kl_prior = normal_kl(
|
||||
mean1=qt_mean,
|
||||
logvar1=qt_log_variance,
|
||||
mean2=torch.tensor([0.0]).to(qt_mean),
|
||||
logvar2=torch.tensor([0.0]).to(qt_log_variance),
|
||||
)
|
||||
assert kl_prior.shape == x_start.shape
|
||||
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
|
||||
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0)
|
||||
|
||||
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
|
||||
|
||||
with torch.no_grad():
|
||||
B, T = x_start.shape[0], self.num_timesteps
|
||||
|
||||
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
|
||||
vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
|
||||
for t in reversed(range(T)):
|
||||
|
||||
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
|
||||
# Calculate VLB term at the current timestep
|
||||
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
|
||||
data_t = torch.cat(
|
||||
[x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)],
|
||||
dim=-1,
|
||||
)
|
||||
new_vals_b, pred_xstart = self._vb_terms_bpd(
|
||||
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
|
||||
clip_denoised=clip_denoised, return_pred_xstart=True)
|
||||
denoise_fn,
|
||||
data_start=x_start,
|
||||
data_t=data_t,
|
||||
t=t_b,
|
||||
clip_denoised=clip_denoised,
|
||||
return_pred_xstart=True,
|
||||
)
|
||||
# MSE for progressive prediction loss
|
||||
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
|
||||
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
|
||||
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
|
||||
assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape
|
||||
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean(
|
||||
dim=list(range(1, len(pred_xstart.shape)))
|
||||
)
|
||||
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
|
||||
# Insert the calculated term into the tensor of all terms
|
||||
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
|
||||
mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
|
||||
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
|
||||
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
|
||||
assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
|
||||
|
||||
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
|
||||
prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :])
|
||||
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
|
||||
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
|
||||
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
|
||||
assert vals_bt_.shape == mse_bt_.shape == torch.Size(
|
||||
[B, T]
|
||||
) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
|
||||
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
|
||||
|
||||
|
||||
|
@ -336,39 +380,53 @@ class PVCNN2(PVCNN2Base):
|
|||
((128, 128, 64), (64, 2, 32)),
|
||||
]
|
||||
|
||||
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
|
||||
voxel_resolution_multiplier=1):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
sv_points,
|
||||
embed_dim,
|
||||
use_att,
|
||||
dropout,
|
||||
extra_feature_channels=3,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
super().__init__(
|
||||
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
|
||||
dropout=dropout, extra_feature_channels=extra_feature_channels,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
||||
num_classes=num_classes,
|
||||
sv_points=sv_points,
|
||||
embed_dim=embed_dim,
|
||||
use_att=use_att,
|
||||
dropout=dropout,
|
||||
extra_feature_channels=extra_feature_channels,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
|
||||
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
|
||||
super(Model, self).__init__()
|
||||
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
|
||||
|
||||
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
|
||||
dropout=args.dropout, extra_feature_channels=0)
|
||||
self.model = PVCNN2(
|
||||
num_classes=args.nc,
|
||||
sv_points=args.svpoints,
|
||||
embed_dim=args.embed_dim,
|
||||
use_att=args.attention,
|
||||
dropout=args.dropout,
|
||||
extra_feature_channels=0,
|
||||
)
|
||||
|
||||
def prior_kl(self, x0):
|
||||
return self.diffusion._prior_bpd(x0)
|
||||
|
||||
def all_kl(self, x0, clip_denoised=True):
|
||||
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
||||
|
||||
return {
|
||||
'total_bpd_b': total_bpd_b,
|
||||
'terms_bpd': vals_bt,
|
||||
'prior_bpd_b': prior_bpd_b,
|
||||
'mse_bt':mse_bt
|
||||
}
|
||||
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
||||
|
||||
return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
|
||||
|
||||
def _denoise(self, data, t):
|
||||
B, D,N= data.shape
|
||||
B, D, N = data.shape
|
||||
assert data.dtype == torch.float
|
||||
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
|
||||
|
||||
|
@ -381,20 +439,22 @@ class Model(nn.Module):
|
|||
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
|
||||
|
||||
if noises is not None:
|
||||
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
|
||||
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
|
||||
|
||||
losses = self.diffusion.p_losses(
|
||||
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
||||
losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
||||
assert losses.shape == t.shape == torch.Size([B])
|
||||
return losses
|
||||
|
||||
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
|
||||
clip_denoised=True,
|
||||
keep_running=False):
|
||||
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
keep_running=keep_running)
|
||||
|
||||
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
|
||||
return self.diffusion.p_sample_loop(
|
||||
partial_x,
|
||||
self._denoise,
|
||||
shape=shape,
|
||||
device=device,
|
||||
noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
keep_running=keep_running,
|
||||
)
|
||||
|
||||
def train(self):
|
||||
self.model.train()
|
||||
|
@ -405,21 +465,19 @@ class Model(nn.Module):
|
|||
def multi_gpu_wrapper(self, f):
|
||||
self.model = f(self.model)
|
||||
|
||||
def get_betas(schedule_type, b_start, b_end, time_num):
|
||||
if schedule_type == 'linear':
|
||||
betas = np.linspace(b_start, b_end, time_num)
|
||||
elif schedule_type == 'warm0.1':
|
||||
|
||||
def get_betas(schedule_type, b_start, b_end, time_num):
|
||||
if schedule_type == "linear":
|
||||
betas = np.linspace(b_start, b_end, time_num)
|
||||
elif schedule_type == "warm0.1":
|
||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||
warmup_time = int(time_num * 0.1)
|
||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||
elif schedule_type == 'warm0.2':
|
||||
|
||||
elif schedule_type == "warm0.2":
|
||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||
warmup_time = int(time_num * 0.2)
|
||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||
elif schedule_type == 'warm0.5':
|
||||
|
||||
elif schedule_type == "warm0.5":
|
||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||
warmup_time = int(time_num * 0.5)
|
||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||
|
@ -428,22 +486,29 @@ def get_betas(schedule_type, b_start, b_end, time_num):
|
|||
return betas
|
||||
|
||||
|
||||
|
||||
#############################################################################
|
||||
|
||||
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
|
||||
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
|
||||
categories=[category], split='train',
|
||||
|
||||
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
|
||||
tr_dataset = ShapeNet15kPointClouds(
|
||||
root_dir=pc_dataroot,
|
||||
categories=[category],
|
||||
split="train",
|
||||
tr_sample_size=npoints,
|
||||
te_sample_size=npoints,
|
||||
scale=1.,
|
||||
scale=1.0,
|
||||
normalize_per_shape=False,
|
||||
normalize_std_per_axis=False,
|
||||
random_subsample=True)
|
||||
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
|
||||
cache=os.path.join(pc_dataroot, '../cache'), split='val',
|
||||
random_subsample=True,
|
||||
)
|
||||
te_dataset = ShapeNet_Multiview_Points(
|
||||
root_pc=pc_dataroot,
|
||||
root_views=views_root,
|
||||
cache=os.path.join(pc_dataroot, "../cache"),
|
||||
split="val",
|
||||
categories=[category],
|
||||
npoints=npoints, sv_samples=200,
|
||||
npoints=npoints,
|
||||
sv_samples=200,
|
||||
all_points_mean=tr_dataset.all_points_mean,
|
||||
all_points_std=tr_dataset.all_points_std,
|
||||
)
|
||||
|
@ -451,39 +516,41 @@ def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
|
|||
|
||||
|
||||
def evaluate_recon_mvr(opt, model, save_dir, logger):
|
||||
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
|
||||
opt.npoints, opt.category)
|
||||
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.category)
|
||||
|
||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
|
||||
shuffle=False, num_workers=int(opt.workers), drop_last=False)
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
|
||||
)
|
||||
ref = []
|
||||
samples = []
|
||||
masked = []
|
||||
k = 0
|
||||
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
|
||||
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Reconstructing Samples"):
|
||||
gt_all = data["test_points"]
|
||||
x_all = data["sv_points"]
|
||||
|
||||
gt_all = data['test_points']
|
||||
x_all = data['sv_points']
|
||||
|
||||
B,V,N,C = x_all.shape
|
||||
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
|
||||
B, V, N, C = x_all.shape
|
||||
gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)
|
||||
|
||||
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
|
||||
|
||||
m, s = data['mean'].float(), data['std'].float()
|
||||
m, s = data["mean"].float(), data["std"].float()
|
||||
|
||||
recon = model.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
|
||||
clip_denoised=False).detach().cpu()
|
||||
recon = (
|
||||
model.gen_samples(
|
||||
x[:, :, : opt.svpoints].cuda(), x[:, :, opt.svpoints :].shape, "cuda", clip_denoised=False
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
)
|
||||
|
||||
recon = recon.transpose(1, 2).contiguous()
|
||||
x = x.transpose(1, 2).contiguous()
|
||||
|
||||
x_adj = x.reshape(B, V, N, C) * s + m
|
||||
recon_adj = recon.reshape(B, V, N, C) * s + m
|
||||
|
||||
x_adj = x.reshape(B,V,N,C)* s + m
|
||||
recon_adj = recon.reshape(B,V,N,C)* s + m
|
||||
|
||||
ref.append( gt_all * s + m)
|
||||
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
|
||||
ref.append(gt_all * s + m)
|
||||
masked.append(x_adj[:, :, : test_dataloader.dataset.sv_samples, :])
|
||||
samples.append(recon_adj)
|
||||
|
||||
ref_pcs = torch.cat(ref, dim=0)
|
||||
|
@ -492,31 +559,40 @@ def evaluate_recon_mvr(opt, model, save_dir, logger):
|
|||
|
||||
B, V, N, C = ref_pcs.shape
|
||||
|
||||
torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, "recon_gt.pth"))
|
||||
|
||||
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
|
||||
|
||||
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
|
||||
torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, "recon_masked.pth"))
|
||||
# Compute metrics
|
||||
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
|
||||
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
|
||||
results = EMD_CD(sample_pcs.reshape(B * V, N, C), ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
|
||||
|
||||
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
|
||||
results = {
|
||||
ky: val.reshape(B, V)
|
||||
if val.shape
|
||||
== torch.Size(
|
||||
[
|
||||
B * V,
|
||||
]
|
||||
)
|
||||
else val
|
||||
for ky, val in results.items()
|
||||
}
|
||||
|
||||
pprint({key: val.mean().item() for key, val in results.items()})
|
||||
logger.info({key: val.mean().item() for key, val in results.items()})
|
||||
|
||||
results['pc'] = sample_pcs
|
||||
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
|
||||
results["pc"] = sample_pcs
|
||||
torch.save(results, os.path.join(save_dir, "ours_results.pth"))
|
||||
|
||||
del ref_pcs, masked, results
|
||||
|
||||
|
||||
def evaluate_saved(opt, saved_dir):
|
||||
# ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
|
||||
|
||||
gt_pth = saved_dir + '/recon_gt.pth'
|
||||
ours_pth = saved_dir + '/ours_results.pth'
|
||||
gt = torch.load(gt_pth).permute(1,0,2,3)
|
||||
ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
|
||||
gt_pth = saved_dir + "/recon_gt.pth"
|
||||
ours_pth = saved_dir + "/ours_results.pth"
|
||||
gt = torch.load(gt_pth).permute(1, 0, 2, 3)
|
||||
ours = torch.load(ours_pth)["pc"].permute(1, 0, 2, 3)
|
||||
|
||||
all_res = {}
|
||||
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
|
||||
|
@ -534,7 +610,6 @@ def evaluate_saved(opt, saved_dir):
|
|||
pprint({key: val.mean().item() for key, val in all_res.items()})
|
||||
|
||||
|
||||
|
||||
def main(opt):
|
||||
exp_id = os.path.splitext(os.path.basename(__file__))[0]
|
||||
dir_id = os.path.dirname(__file__)
|
||||
|
@ -542,7 +617,7 @@ def main(opt):
|
|||
copy_source(__file__, output_dir)
|
||||
logger = setup_logging(output_dir)
|
||||
|
||||
outf_syn, = setup_output_subdirs(output_dir, 'syn')
|
||||
(outf_syn,) = setup_output_subdirs(output_dir, "syn")
|
||||
|
||||
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
||||
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
||||
|
@ -559,12 +634,10 @@ def main(opt):
|
|||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
logger.info("Resume Path:%s" % opt.model)
|
||||
|
||||
resumed_param = torch.load(opt.model)
|
||||
model.load_state_dict(resumed_param['model_state'])
|
||||
|
||||
model.load_state_dict(resumed_param["model_state"])
|
||||
|
||||
if opt.eval_recon_mvr:
|
||||
# Evaluate generation
|
||||
|
@ -575,47 +648,44 @@ def main(opt):
|
|||
|
||||
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/')
|
||||
parser.add_argument('--dataroot_sv', default='GenReData/')
|
||||
parser.add_argument('--category', default='chair')
|
||||
parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/")
|
||||
parser.add_argument("--dataroot_sv", default="GenReData/")
|
||||
parser.add_argument("--category", default="chair")
|
||||
|
||||
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
|
||||
parser.add_argument('--workers', type=int, default=16, help='workers')
|
||||
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
|
||||
parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
|
||||
parser.add_argument("--workers", type=int, default=16, help="workers")
|
||||
parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
|
||||
|
||||
parser.add_argument('--eval_recon_mvr', default=True)
|
||||
parser.add_argument('--eval_saved', default=True)
|
||||
parser.add_argument("--eval_recon_mvr", default=True)
|
||||
parser.add_argument("--eval_saved", default=True)
|
||||
|
||||
parser.add_argument('--nc', default=3)
|
||||
parser.add_argument('--npoints', default=2048)
|
||||
parser.add_argument('--svpoints', default=200)
|
||||
'''model'''
|
||||
parser.add_argument('--beta_start', default=0.0001)
|
||||
parser.add_argument('--beta_end', default=0.02)
|
||||
parser.add_argument('--schedule_type', default='linear')
|
||||
parser.add_argument('--time_num', default=1000)
|
||||
parser.add_argument("--nc", default=3)
|
||||
parser.add_argument("--npoints", default=2048)
|
||||
parser.add_argument("--svpoints", default=200)
|
||||
"""model"""
|
||||
parser.add_argument("--beta_start", default=0.0001)
|
||||
parser.add_argument("--beta_end", default=0.02)
|
||||
parser.add_argument("--schedule_type", default="linear")
|
||||
parser.add_argument("--time_num", default=1000)
|
||||
|
||||
#params
|
||||
parser.add_argument('--attention', default=True)
|
||||
parser.add_argument('--dropout', default=0.1)
|
||||
parser.add_argument('--embed_dim', type=int, default=64)
|
||||
parser.add_argument('--loss_type', default='mse')
|
||||
parser.add_argument('--model_mean_type', default='eps')
|
||||
parser.add_argument('--model_var_type', default='fixedsmall')
|
||||
# params
|
||||
parser.add_argument("--attention", default=True)
|
||||
parser.add_argument("--dropout", default=0.1)
|
||||
parser.add_argument("--embed_dim", type=int, default=64)
|
||||
parser.add_argument("--loss_type", default="mse")
|
||||
parser.add_argument("--model_mean_type", default="eps")
|
||||
parser.add_argument("--model_var_type", default="fixedsmall")
|
||||
|
||||
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
|
||||
|
||||
parser.add_argument('--model', default='', required=True, help="path to model (to continue training)")
|
||||
"""eval"""
|
||||
|
||||
'''eval'''
|
||||
parser.add_argument("--eval_path", default="")
|
||||
|
||||
parser.add_argument('--eval_path',
|
||||
default='')
|
||||
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
|
||||
|
||||
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
|
||||
|
||||
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
|
||||
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
|
@ -625,7 +695,9 @@ def parse_args():
|
|||
opt.cuda = False
|
||||
|
||||
return opt
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = parse_args()
|
||||
|
||||
main(opt)
|
||||
|
|
|
@ -1,31 +1,30 @@
|
|||
import torch
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
||||
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
|
||||
import argparse
|
||||
from torch.distributions import Normal
|
||||
|
||||
from utils.file_utils import *
|
||||
from utils.visualize import *
|
||||
from model.pvcnn_generation import PVCNN2Base
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
|
||||
from metrics.evaluation_metrics import compute_all_metrics
|
||||
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
||||
from model.pvcnn_generation import PVCNN2Base
|
||||
from utils.file_utils import *
|
||||
from utils.visualize import *
|
||||
|
||||
'''
|
||||
"""
|
||||
models
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
KL divergence between normal distributions parameterized by mean and log-variance.
|
||||
"""
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
||||
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||
# Assumes data is integers [0, 1]
|
||||
|
@ -36,37 +35,40 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
|||
inv_stdv = torch.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 0.5)
|
||||
cdf_plus = px0.cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - .5)
|
||||
min_in = inv_stdv * (centered_x - 0.5)
|
||||
cdf_min = px0.cdf(min_in)
|
||||
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
|
||||
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
|
||||
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
|
||||
log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
|
||||
log_probs = torch.where(
|
||||
x < 0.001, log_cdf_plus,
|
||||
torch.where(x > 0.999, log_one_minus_cdf_min,
|
||||
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
|
||||
x < 0.001,
|
||||
log_cdf_plus,
|
||||
torch.where(
|
||||
x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
|
||||
),
|
||||
)
|
||||
assert log_probs.shape == x.shape
|
||||
return log_probs
|
||||
|
||||
|
||||
class GaussianDiffusion:
|
||||
def __init__(self,betas, loss_type, model_mean_type, model_var_type):
|
||||
def __init__(self, betas, loss_type, model_mean_type, model_var_type):
|
||||
self.loss_type = loss_type
|
||||
self.model_mean_type = model_mean_type
|
||||
self.model_var_type = model_var_type
|
||||
assert isinstance(betas, np.ndarray)
|
||||
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
|
||||
assert (betas > 0).all() and (betas <= 1).all()
|
||||
timesteps, = betas.shape
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
|
||||
# initialize twice the actual length so we can keep running for eval
|
||||
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
|
||||
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
|
||||
alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
|
||||
|
||||
self.betas = torch.from_numpy(betas).float()
|
||||
self.alphas_cumprod = alphas_cumprod.float()
|
||||
|
@ -74,21 +76,23 @@ class GaussianDiffusion:
|
|||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
|
||||
self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
|
||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
|
||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
|
||||
|
||||
betas = torch.from_numpy(betas).float()
|
||||
alphas = torch.from_numpy(alphas).float()
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
self.posterior_variance = posterior_variance
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
|
||||
self.posterior_log_variance_clipped = torch.log(
|
||||
torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
|
||||
)
|
||||
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
|
||||
@staticmethod
|
||||
def _extract(a, t, x_shape):
|
||||
|
@ -96,17 +100,15 @@ class GaussianDiffusion:
|
|||
Extract some coefficients at specified timesteps,
|
||||
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
|
||||
"""
|
||||
bs, = t.shape
|
||||
(bs,) = t.shape
|
||||
assert x_shape[0] == bs
|
||||
out = torch.gather(a, 0, t)
|
||||
assert out.shape == torch.Size([bs])
|
||||
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
|
||||
|
||||
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||
variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
|
@ -118,56 +120,62 @@ class GaussianDiffusion:
|
|||
noise = torch.randn(x_start.shape, device=x_start.device)
|
||||
assert noise.shape == x_start.shape
|
||||
return (
|
||||
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
||||
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
||||
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||
+ self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
|
||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
|
||||
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
|
||||
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
||||
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
|
||||
+ self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
|
||||
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
|
||||
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
|
||||
x_start.shape[0])
|
||||
posterior_log_variance_clipped = self._extract(
|
||||
self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
|
||||
)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
|
||||
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
|
||||
|
||||
model_output = denoise_fn(data, t)
|
||||
|
||||
|
||||
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
|
||||
if self.model_var_type in ["fixedsmall", "fixedlarge"]:
|
||||
# below: only log_variance is used in the KL computations
|
||||
model_variance, model_log_variance = {
|
||||
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
|
||||
'fixedlarge': (self.betas.to(data.device),
|
||||
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
|
||||
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
|
||||
"fixedlarge": (
|
||||
self.betas.to(data.device),
|
||||
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
|
||||
),
|
||||
"fixedsmall": (
|
||||
self.posterior_variance.to(data.device),
|
||||
self.posterior_log_variance_clipped.to(data.device),
|
||||
),
|
||||
}[self.model_var_type]
|
||||
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
|
||||
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
|
||||
else:
|
||||
raise NotImplementedError(self.model_var_type)
|
||||
|
||||
if self.model_mean_type == 'eps':
|
||||
if self.model_mean_type == "eps":
|
||||
x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
|
||||
|
||||
if clip_denoised:
|
||||
x_recon = torch.clamp(x_recon, -.5, .5)
|
||||
x_recon = torch.clamp(x_recon, -0.5, 0.5)
|
||||
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
|
||||
else:
|
||||
raise NotImplementedError(self.loss_type)
|
||||
|
||||
|
||||
assert model_mean.shape == x_recon.shape == data.shape
|
||||
assert model_variance.shape == model_log_variance.shape == data.shape
|
||||
if return_pred_xstart:
|
||||
|
@ -178,18 +186,19 @@ class GaussianDiffusion:
|
|||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||
assert x_t.shape == eps.shape
|
||||
return (
|
||||
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
|
||||
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
||||
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
|
||||
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
''' samples '''
|
||||
""" samples """
|
||||
|
||||
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
|
||||
"""
|
||||
Sample from the model
|
||||
"""
|
||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
|
||||
return_pred_xstart=True)
|
||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
||||
denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
|
||||
)
|
||||
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
|
||||
assert noise.shape == data.shape
|
||||
# no noise when t == 0
|
||||
|
@ -201,10 +210,17 @@ class GaussianDiffusion:
|
|||
assert sample.shape == pred_xstart.shape
|
||||
return (sample, pred_xstart) if return_pred_xstart else sample
|
||||
|
||||
|
||||
def p_sample_loop(self, denoise_fn, shape, device,
|
||||
noise_fn=torch.randn, constrain_fn=lambda x, t:x,
|
||||
clip_denoised=True, max_timestep=None, keep_running=False):
|
||||
def p_sample_loop(
|
||||
self,
|
||||
denoise_fn,
|
||||
shape,
|
||||
device,
|
||||
noise_fn=torch.randn,
|
||||
constrain_fn=lambda x, t: x,
|
||||
clip_denoised=True,
|
||||
max_timestep=None,
|
||||
keep_running=False,
|
||||
):
|
||||
"""
|
||||
Generate samples
|
||||
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
|
||||
|
@ -220,28 +236,38 @@ class GaussianDiffusion:
|
|||
for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
|
||||
img_t = constrain_fn(img_t, t)
|
||||
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised, return_pred_xstart=False).detach()
|
||||
|
||||
img_t = self.p_sample(
|
||||
denoise_fn=denoise_fn,
|
||||
data=img_t,
|
||||
t=t_,
|
||||
noise_fn=noise_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
return_pred_xstart=False,
|
||||
).detach()
|
||||
|
||||
assert img_t.shape == shape
|
||||
return img_t
|
||||
|
||||
def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
|
||||
|
||||
def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
|
||||
assert t >= 1
|
||||
|
||||
t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
|
||||
t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
|
||||
encoding = self.q_sample(x0, t_vec)
|
||||
|
||||
img_t = encoding
|
||||
|
||||
for k in reversed(range(0,t)):
|
||||
for k in reversed(range(0, t)):
|
||||
img_t = constrain_fn(img_t, k)
|
||||
t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
|
||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
|
||||
clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
|
||||
|
||||
img_t = self.p_sample(
|
||||
denoise_fn=denoise_fn,
|
||||
data=img_t,
|
||||
t=t_,
|
||||
noise_fn=noise_fn,
|
||||
clip_denoised=False,
|
||||
return_pred_xstart=False,
|
||||
use_var=True,
|
||||
).detach()
|
||||
|
||||
return img_t
|
||||
|
||||
|
@ -260,40 +286,50 @@ class PVCNN2(PVCNN2Base):
|
|||
((128, 128, 64), (64, 2, 32)),
|
||||
]
|
||||
|
||||
def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
|
||||
voxel_resolution_multiplier=1):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
embed_dim,
|
||||
use_att,
|
||||
dropout,
|
||||
extra_feature_channels=3,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
):
|
||||
super().__init__(
|
||||
num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
|
||||
dropout=dropout, extra_feature_channels=extra_feature_channels,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
||||
num_classes=num_classes,
|
||||
embed_dim=embed_dim,
|
||||
use_att=use_att,
|
||||
dropout=dropout,
|
||||
extra_feature_channels=extra_feature_channels,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
|
||||
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
|
||||
super(Model, self).__init__()
|
||||
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
|
||||
|
||||
self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
|
||||
dropout=args.dropout, extra_feature_channels=0)
|
||||
self.model = PVCNN2(
|
||||
num_classes=args.nc,
|
||||
embed_dim=args.embed_dim,
|
||||
use_att=args.attention,
|
||||
dropout=args.dropout,
|
||||
extra_feature_channels=0,
|
||||
)
|
||||
|
||||
def prior_kl(self, x0):
|
||||
return self.diffusion._prior_bpd(x0)
|
||||
|
||||
def all_kl(self, x0, clip_denoised=True):
|
||||
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
||||
|
||||
return {
|
||||
'total_bpd_b': total_bpd_b,
|
||||
'terms_bpd': vals_bt,
|
||||
'prior_bpd_b': prior_bpd_b,
|
||||
'mse_bt':mse_bt
|
||||
}
|
||||
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
||||
|
||||
return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
|
||||
|
||||
def _denoise(self, data, t):
|
||||
B, D,N= data.shape
|
||||
B, D, N = data.shape
|
||||
assert data.dtype == torch.float
|
||||
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
|
||||
|
||||
|
@ -307,23 +343,34 @@ class Model(nn.Module):
|
|||
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
|
||||
|
||||
if noises is not None:
|
||||
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
|
||||
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
|
||||
|
||||
losses = self.diffusion.p_losses(
|
||||
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
||||
losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
||||
assert losses.shape == t.shape == torch.Size([B])
|
||||
return losses
|
||||
|
||||
def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
|
||||
clip_denoised=False, max_timestep=None,
|
||||
keep_running=False):
|
||||
return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
|
||||
constrain_fn=constrain_fn,
|
||||
clip_denoised=clip_denoised, max_timestep=max_timestep,
|
||||
keep_running=keep_running)
|
||||
|
||||
def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
|
||||
def gen_samples(
|
||||
self,
|
||||
shape,
|
||||
device,
|
||||
noise_fn=torch.randn,
|
||||
constrain_fn=lambda x, t: x,
|
||||
clip_denoised=False,
|
||||
max_timestep=None,
|
||||
keep_running=False,
|
||||
):
|
||||
return self.diffusion.p_sample_loop(
|
||||
self._denoise,
|
||||
shape=shape,
|
||||
device=device,
|
||||
noise_fn=noise_fn,
|
||||
constrain_fn=constrain_fn,
|
||||
clip_denoised=clip_denoised,
|
||||
max_timestep=max_timestep,
|
||||
keep_running=keep_running,
|
||||
)
|
||||
|
||||
def reconstruct(self, x0, t, constrain_fn=lambda x, t: x):
|
||||
return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
|
||||
|
||||
def train(self):
|
||||
|
@ -337,20 +384,17 @@ class Model(nn.Module):
|
|||
|
||||
|
||||
def get_betas(schedule_type, b_start, b_end, time_num):
|
||||
if schedule_type == 'linear':
|
||||
if schedule_type == "linear":
|
||||
betas = np.linspace(b_start, b_end, time_num)
|
||||
elif schedule_type == 'warm0.1':
|
||||
|
||||
elif schedule_type == "warm0.1":
|
||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||
warmup_time = int(time_num * 0.1)
|
||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||
elif schedule_type == 'warm0.2':
|
||||
|
||||
elif schedule_type == "warm0.2":
|
||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||
warmup_time = int(time_num * 0.2)
|
||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||
elif schedule_type == 'warm0.5':
|
||||
|
||||
elif schedule_type == "warm0.5":
|
||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||
warmup_time = int(time_num * 0.5)
|
||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||
|
@ -358,111 +402,109 @@ def get_betas(schedule_type, b_start, b_end, time_num):
|
|||
raise NotImplementedError(schedule_type)
|
||||
return betas
|
||||
|
||||
|
||||
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
|
||||
'''
|
||||
"""
|
||||
|
||||
:param target_shape_constraint: target voxels
|
||||
:return: constrained x
|
||||
'''
|
||||
"""
|
||||
# eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
|
||||
eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 ))
|
||||
def constrain_fn(x, t):
|
||||
eps_ = eps_all[t] if (t<1000) else 0
|
||||
for _ in range(num_steps):
|
||||
x = x - eps_ * ((x - ground_truth) * mask)
|
||||
eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2))
|
||||
|
||||
def constrain_fn(x, t):
|
||||
eps_ = eps_all[t] if (t < 1000) else 0
|
||||
for _ in range(num_steps):
|
||||
x = x - eps_ * ((x - ground_truth) * mask)
|
||||
|
||||
return x
|
||||
|
||||
return constrain_fn
|
||||
|
||||
|
||||
#############################################################################
|
||||
|
||||
def get_dataset(dataroot, npoints,category,use_mask=False):
|
||||
tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
|
||||
categories=[category], split='train',
|
||||
|
||||
def get_dataset(dataroot, npoints, category, use_mask=False):
|
||||
tr_dataset = ShapeNet15kPointClouds(
|
||||
root_dir=dataroot,
|
||||
categories=[category],
|
||||
split="train",
|
||||
tr_sample_size=npoints,
|
||||
te_sample_size=npoints,
|
||||
scale=1.,
|
||||
scale=1.0,
|
||||
normalize_per_shape=False,
|
||||
normalize_std_per_axis=False,
|
||||
random_subsample=True, use_mask = use_mask)
|
||||
te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
|
||||
categories=[category], split='val',
|
||||
random_subsample=True,
|
||||
use_mask=use_mask,
|
||||
)
|
||||
te_dataset = ShapeNet15kPointClouds(
|
||||
root_dir=dataroot,
|
||||
categories=[category],
|
||||
split="val",
|
||||
tr_sample_size=npoints,
|
||||
te_sample_size=npoints,
|
||||
scale=1.,
|
||||
scale=1.0,
|
||||
normalize_per_shape=False,
|
||||
normalize_std_per_axis=False,
|
||||
all_points_mean=tr_dataset.all_points_mean,
|
||||
all_points_std=tr_dataset.all_points_std,
|
||||
use_mask=use_mask
|
||||
use_mask=use_mask,
|
||||
)
|
||||
return tr_dataset, te_dataset
|
||||
|
||||
|
||||
|
||||
def evaluate_gen(opt, ref_pcs, logger):
|
||||
|
||||
if ref_pcs is None:
|
||||
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False)
|
||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
|
||||
shuffle=False, num_workers=int(opt.workers), drop_last=False)
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
|
||||
)
|
||||
ref = []
|
||||
for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
|
||||
x = data['test_points']
|
||||
m, s = data['mean'].float(), data['std'].float()
|
||||
for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
|
||||
x = data["test_points"]
|
||||
m, s = data["mean"].float(), data["std"].float()
|
||||
|
||||
ref.append(x*s + m)
|
||||
ref.append(x * s + m)
|
||||
|
||||
ref_pcs = torch.cat(ref, dim=0).contiguous()
|
||||
|
||||
logger.info("Loading sample path: %s"
|
||||
% (opt.eval_path))
|
||||
logger.info("Loading sample path: %s" % (opt.eval_path))
|
||||
sample_pcs = torch.load(opt.eval_path).contiguous()
|
||||
|
||||
logger.info("Generation sample size:%s reference size: %s"
|
||||
% (sample_pcs.size(), ref_pcs.size()))
|
||||
|
||||
logger.info("Generation sample size:%s reference size: %s" % (sample_pcs.size(), ref_pcs.size()))
|
||||
|
||||
# Compute metrics
|
||||
results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
|
||||
results = {k: (v.cpu().detach().item()
|
||||
if not isinstance(v, float) else v) for k, v in results.items()}
|
||||
results = {k: (v.cpu().detach().item() if not isinstance(v, float) else v) for k, v in results.items()}
|
||||
|
||||
pprint(results)
|
||||
logger.info(results)
|
||||
|
||||
jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
|
||||
pprint('JSD: {}'.format(jsd))
|
||||
logger.info('JSD: {}'.format(jsd))
|
||||
|
||||
pprint("JSD: {}".format(jsd))
|
||||
logger.info("JSD: {}".format(jsd))
|
||||
|
||||
|
||||
def generate(model, opt):
|
||||
|
||||
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category)
|
||||
|
||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
|
||||
shuffle=False, num_workers=int(opt.workers), drop_last=False)
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
samples = []
|
||||
ref = []
|
||||
|
||||
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
|
||||
|
||||
x = data['test_points'].transpose(1,2)
|
||||
m, s = data['mean'].float(), data['std'].float()
|
||||
|
||||
gen = model.gen_samples(x.shape,
|
||||
'cuda', clip_denoised=False).detach().cpu()
|
||||
|
||||
gen = gen.transpose(1,2).contiguous()
|
||||
x = x.transpose(1,2).contiguous()
|
||||
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Generating Samples"):
|
||||
x = data["test_points"].transpose(1, 2)
|
||||
m, s = data["mean"].float(), data["std"].float()
|
||||
|
||||
gen = model.gen_samples(x.shape, "cuda", clip_denoised=False).detach().cpu()
|
||||
|
||||
gen = gen.transpose(1, 2).contiguous()
|
||||
x = x.transpose(1, 2).contiguous()
|
||||
|
||||
gen = gen * s + m
|
||||
x = x * s + m
|
||||
|
@ -482,20 +524,20 @@ def generate(model, opt):
|
|||
# 1,
|
||||
# 0.5,
|
||||
# )
|
||||
|
||||
|
||||
# visualize using matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib
|
||||
matplotlib.use('TkAgg')
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
matplotlib.use("TkAgg")
|
||||
for idx, pc in enumerate(gen[:64]):
|
||||
print(f"Visualizing point cloud {idx}...")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
ax.scatter(pc[:,0], pc[:,1], pc[:,2], c=pc[:,2], cmap=cm.jet)
|
||||
ax.set_aspect('equal')
|
||||
ax.axis('off')
|
||||
ax = fig.add_subplot(111, projection="3d")
|
||||
ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
# ax.set_xlim(-1, 1)
|
||||
# ax.set_ylim(-1, 1)
|
||||
# ax.set_zlim(-1, 1)
|
||||
|
@ -507,17 +549,14 @@ def generate(model, opt):
|
|||
|
||||
torch.save(samples, opt.eval_path)
|
||||
|
||||
|
||||
|
||||
return ref
|
||||
|
||||
|
||||
def main(opt):
|
||||
|
||||
if opt.category == 'airplane':
|
||||
if opt.category == "airplane":
|
||||
opt.beta_start = 1e-5
|
||||
opt.beta_end = 0.008
|
||||
opt.schedule_type = 'warm0.1'
|
||||
opt.schedule_type = "warm0.1"
|
||||
|
||||
exp_id = os.path.splitext(os.path.basename(__file__))[0]
|
||||
dir_id = os.path.dirname(__file__)
|
||||
|
@ -525,7 +564,7 @@ def main(opt):
|
|||
copy_source(__file__, output_dir)
|
||||
logger = setup_logging(output_dir)
|
||||
|
||||
outf_syn, = setup_output_subdirs(output_dir, 'syn')
|
||||
(outf_syn,) = setup_output_subdirs(output_dir, "syn")
|
||||
|
||||
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
||||
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
||||
|
@ -542,64 +581,59 @@ def main(opt):
|
|||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
logger.info("Resume Path:%s" % opt.model)
|
||||
|
||||
resumed_param = torch.load(opt.model)
|
||||
model.load_state_dict(resumed_param['model_state'])
|
||||
|
||||
model.load_state_dict(resumed_param["model_state"])
|
||||
|
||||
ref = None
|
||||
if opt.generate:
|
||||
opt.eval_path = os.path.join(outf_syn, 'samples.pth')
|
||||
opt.eval_path = os.path.join(outf_syn, "samples.pth")
|
||||
Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
ref=generate(model, opt)
|
||||
|
||||
ref = generate(model, opt)
|
||||
|
||||
if opt.eval_gen:
|
||||
# Evaluate generation
|
||||
evaluate_gen(opt, ref, logger)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
|
||||
parser.add_argument('--category', default='chair')
|
||||
parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
|
||||
parser.add_argument("--category", default="chair")
|
||||
|
||||
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
|
||||
parser.add_argument('--workers', type=int, default=16, help='workers')
|
||||
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
|
||||
parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
|
||||
parser.add_argument("--workers", type=int, default=16, help="workers")
|
||||
parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
|
||||
|
||||
parser.add_argument('--generate',default=True)
|
||||
parser.add_argument('--eval_gen', default=True)
|
||||
parser.add_argument("--generate", default=True)
|
||||
parser.add_argument("--eval_gen", default=True)
|
||||
|
||||
parser.add_argument('--nc', default=3)
|
||||
parser.add_argument('--npoints', default=2048)
|
||||
'''model'''
|
||||
parser.add_argument('--beta_start', default=0.0001)
|
||||
parser.add_argument('--beta_end', default=0.02)
|
||||
parser.add_argument('--schedule_type', default='linear')
|
||||
parser.add_argument('--time_num', default=1000)
|
||||
parser.add_argument("--nc", default=3)
|
||||
parser.add_argument("--npoints", default=2048)
|
||||
"""model"""
|
||||
parser.add_argument("--beta_start", default=0.0001)
|
||||
parser.add_argument("--beta_end", default=0.02)
|
||||
parser.add_argument("--schedule_type", default="linear")
|
||||
parser.add_argument("--time_num", default=1000)
|
||||
|
||||
#params
|
||||
parser.add_argument('--attention', default=True)
|
||||
parser.add_argument('--dropout', default=0.1)
|
||||
parser.add_argument('--embed_dim', type=int, default=64)
|
||||
parser.add_argument('--loss_type', default='mse')
|
||||
parser.add_argument('--model_mean_type', default='eps')
|
||||
parser.add_argument('--model_var_type', default='fixedsmall')
|
||||
# params
|
||||
parser.add_argument("--attention", default=True)
|
||||
parser.add_argument("--dropout", default=0.1)
|
||||
parser.add_argument("--embed_dim", type=int, default=64)
|
||||
parser.add_argument("--loss_type", default="mse")
|
||||
parser.add_argument("--model_mean_type", default="eps")
|
||||
parser.add_argument("--model_var_type", default="fixedsmall")
|
||||
|
||||
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
|
||||
|
||||
parser.add_argument('--model', default='',required=True, help="path to model (to continue training)")
|
||||
"""eval"""
|
||||
|
||||
'''eval'''
|
||||
parser.add_argument("--eval_path", default="")
|
||||
|
||||
parser.add_argument('--eval_path',
|
||||
default='')
|
||||
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
|
||||
|
||||
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
|
||||
|
||||
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
|
||||
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
|
@ -609,7 +643,9 @@ def parse_args():
|
|||
opt.cuda = False
|
||||
|
||||
return opt
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = parse_args()
|
||||
set_seed(opt)
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -1,35 +1,33 @@
|
|||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
from shutil import copyfile
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger()
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def set_global_gpu_env(opt):
|
||||
|
||||
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)
|
||||
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
|
||||
|
||||
torch.cuda.set_device(opt.gpu)
|
||||
|
||||
|
||||
def copy_source(file, output_dir):
|
||||
copyfile(file, os.path.join(output_dir, os.path.basename(file)))
|
||||
|
||||
|
||||
|
||||
def setup_logging(output_dir):
|
||||
log_format = logging.Formatter("%(asctime)s : %(message)s")
|
||||
logger = logging.getLogger()
|
||||
logger.handlers = []
|
||||
output_file = os.path.join(output_dir, 'output.log')
|
||||
output_file = os.path.join(output_dir, "output.log")
|
||||
file_handler = logging.FileHandler(output_file)
|
||||
file_handler.setFormatter(log_format)
|
||||
logger.addHandler(file_handler)
|
||||
|
@ -44,16 +42,14 @@ def setup_logging(output_dir):
|
|||
|
||||
|
||||
def get_output_dir(prefix, exp_id):
|
||||
t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
output_dir = os.path.join(prefix, 'output/' + exp_id, t)
|
||||
t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
output_dir = os.path.join(prefix, "output/" + exp_id, t)
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
return output_dir
|
||||
|
||||
|
||||
|
||||
def set_seed(opt):
|
||||
|
||||
if opt.manualSeed is None:
|
||||
opt.manualSeed = random.randint(1, 10000)
|
||||
print("Random Seed: ", opt.manualSeed)
|
||||
|
@ -65,8 +61,8 @@ def set_seed(opt):
|
|||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def setup_output_subdirs(output_dir, *subfolders):
|
||||
|
||||
def setup_output_subdirs(output_dir, *subfolders):
|
||||
output_subdirs = output_dir
|
||||
try:
|
||||
os.makedirs(output_subdirs)
|
||||
|
@ -82,4 +78,4 @@ def setup_output_subdirs(output_dir, *subfolders):
|
|||
pass
|
||||
subfolder_list.append(curr_subf)
|
||||
|
||||
return subfolder_list
|
||||
return subfolder_list
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
import numpy as np
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import entropy
|
||||
|
||||
|
||||
def iterate_in_chunks(l, n):
|
||||
'''Yield successive 'n'-sized chunks from iterable 'l'.
|
||||
"""Yield successive 'n'-sized chunks from iterable 'l'.
|
||||
Note: last chunk will be smaller than l if n doesn't divide l perfectly.
|
||||
'''
|
||||
"""
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
yield l[i : i + n]
|
||||
|
||||
|
||||
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
||||
'''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
|
||||
"""Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
|
||||
that is placed in the unit-cube.
|
||||
If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
|
||||
'''
|
||||
"""
|
||||
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
|
||||
spacing = 1.0 / float(resolution - 1)
|
||||
for i in range(resolution):
|
||||
|
@ -30,9 +32,11 @@ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
|||
|
||||
return grid, spacing
|
||||
|
||||
def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False,
|
||||
use_EMD=False):
|
||||
'''Computes the MMD between two sets of point-clouds.
|
||||
|
||||
def minimum_mathing_distance(
|
||||
sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False
|
||||
):
|
||||
"""Computes the MMD between two sets of point-clouds.
|
||||
Args:
|
||||
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and
|
||||
compared to a set of "reference" point-clouds.
|
||||
|
@ -49,17 +53,17 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
|
|||
use_EMD (boolean: If true, the matchings are based on the EMD.
|
||||
Returns:
|
||||
A tuple containing the MMD and all the matched distances of which the MMD is their mean.
|
||||
'''
|
||||
"""
|
||||
|
||||
n_ref, n_pc_points, pc_dim = ref_pcs.shape
|
||||
_, n_pc_points_s, pc_dim_s = sample_pcs.shape
|
||||
|
||||
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
|
||||
raise ValueError('Incompatible size of point-clouds.')
|
||||
raise ValueError("Incompatible size of point-clouds.")
|
||||
|
||||
ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(n_pc_points, normalize=normalize,
|
||||
sess=sess, use_sqrt=use_sqrt,
|
||||
use_EMD=use_EMD)
|
||||
ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(
|
||||
n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
|
||||
)
|
||||
matched_dists = []
|
||||
for i in range(n_ref):
|
||||
best_in_all_batches = []
|
||||
|
@ -75,9 +79,18 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
|
|||
return mmd, matched_dists
|
||||
|
||||
|
||||
def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False,
|
||||
ret_dist=False):
|
||||
'''Computes the Coverage between two sets of point-clouds.
|
||||
def coverage(
|
||||
sample_pcs,
|
||||
ref_pcs,
|
||||
batch_size,
|
||||
normalize=True,
|
||||
sess=None,
|
||||
verbose=False,
|
||||
use_sqrt=False,
|
||||
use_EMD=False,
|
||||
ret_dist=False,
|
||||
):
|
||||
"""Computes the Coverage between two sets of point-clouds.
|
||||
Args:
|
||||
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
|
||||
and compared to a set of "reference" point-clouds.
|
||||
|
@ -97,18 +110,16 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
|
|||
Returns: the coverage score (int),
|
||||
the indices of the ref_pcs that are matched with each sample_pc
|
||||
and optionally the matched distances of the samples_pcs.
|
||||
'''
|
||||
"""
|
||||
n_ref, n_pc_points, pc_dim = ref_pcs.shape
|
||||
n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape
|
||||
|
||||
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
|
||||
raise ValueError('Incompatible Point-Clouds.')
|
||||
raise ValueError("Incompatible Point-Clouds.")
|
||||
|
||||
ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(n_pc_points,
|
||||
normalize=normalize,
|
||||
sess=sess,
|
||||
use_sqrt=use_sqrt,
|
||||
use_EMD=use_EMD)
|
||||
ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(
|
||||
n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
|
||||
)
|
||||
matched_gt = []
|
||||
matched_dist = []
|
||||
for i in xrange(n_sam):
|
||||
|
@ -140,12 +151,12 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
|
|||
|
||||
|
||||
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
|
||||
'''Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
|
||||
"""Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
|
||||
Args:
|
||||
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
|
||||
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
|
||||
resolution: (int) grid-resolution. Affects granularity of measurements.
|
||||
'''
|
||||
"""
|
||||
in_unit_sphere = True
|
||||
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
|
||||
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
|
||||
|
@ -153,19 +164,19 @@ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
|
|||
|
||||
|
||||
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
||||
'''Given a collection of point-clouds, estimate the entropy of the random variables
|
||||
"""Given a collection of point-clouds, estimate the entropy of the random variables
|
||||
corresponding to occupancy-grid activation patterns.
|
||||
Inputs:
|
||||
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
|
||||
grid_resolution (int) size of occupancy grid that will be used.
|
||||
'''
|
||||
"""
|
||||
epsilon = 10e-4
|
||||
bound = 0.5 + epsilon
|
||||
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
||||
warnings.warn('Point-clouds are not in unit cube.')
|
||||
warnings.warn("Point-clouds are not in unit cube.")
|
||||
|
||||
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
|
||||
warnings.warn('Point-clouds are not in unit sphere.')
|
||||
if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
|
||||
warnings.warn("Point-clouds are not in unit sphere.")
|
||||
|
||||
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
||||
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
||||
|
@ -192,13 +203,14 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
|||
|
||||
return acc_entropy / len(grid_counters), grid_counters
|
||||
|
||||
|
||||
def jensen_shannon_divergence(P, Q):
|
||||
if np.any(P < 0) or np.any(Q < 0):
|
||||
raise ValueError('Negative values.')
|
||||
raise ValueError("Negative values.")
|
||||
if len(P) != len(Q):
|
||||
raise ValueError('Non equal size.')
|
||||
raise ValueError("Non equal size.")
|
||||
|
||||
P_ = P / np.sum(P) # Ensure probabilities.
|
||||
P_ = P / np.sum(P) # Ensure probabilities.
|
||||
Q_ = Q / np.sum(Q)
|
||||
|
||||
e1 = entropy(P_, base=2)
|
||||
|
@ -209,13 +221,14 @@ def jensen_shannon_divergence(P, Q):
|
|||
res2 = _jsdiv(P_, Q_)
|
||||
|
||||
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
||||
warnings.warn('Numerical values of two JSD methods don\'t agree.')
|
||||
warnings.warn("Numerical values of two JSD methods don't agree.")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _jsdiv(P, Q):
|
||||
'''another way of computing JSD'''
|
||||
"""another way of computing JSD"""
|
||||
|
||||
def _kldiv(A, B):
|
||||
a = A.copy()
|
||||
b = B.copy()
|
||||
|
@ -229,4 +242,4 @@ def _jsdiv(P, Q):
|
|||
|
||||
M = 0.5 * (P_ + Q_)
|
||||
|
||||
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
|
||||
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
|
||||
|
|
|
@ -1,32 +1,33 @@
|
|||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
||||
import numpy as np
|
||||
|
||||
matplotlib.use("agg")
|
||||
import os
|
||||
import trimesh
|
||||
from pathlib import Path
|
||||
|
||||
'''
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import trimesh
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
"""
|
||||
Custom visualization
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def export_to_pc_batch(dir, pcs, colors=None):
|
||||
|
||||
Path(dir).mkdir(parents=True, exist_ok=True)
|
||||
for i, xyz in enumerate(pcs):
|
||||
if colors is None:
|
||||
color = None
|
||||
else:
|
||||
color = colors[i]
|
||||
pcwrite(os.path.join(dir, 'sample_'+str(i)+'.ply'), xyz, color)
|
||||
pcwrite(os.path.join(dir, "sample_" + str(i) + ".ply"), xyz, color)
|
||||
|
||||
|
||||
def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)):
|
||||
'''
|
||||
def export_to_obj(dir, meshes, transform=lambda v, f: (v, f)):
|
||||
"""
|
||||
transform: f(vertices, faces) --> transformed (vertices, faces)
|
||||
'''
|
||||
"""
|
||||
Path(dir).mkdir(parents=True, exist_ok=True)
|
||||
for i, data in enumerate(meshes):
|
||||
v, f = transform(data[0], data[1])
|
||||
|
@ -36,14 +37,15 @@ def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)):
|
|||
v_color = None
|
||||
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
||||
out = trimesh.exchange.obj.export_obj(mesh)
|
||||
with open(os.path.join(dir, 'sample_'+str(i)+'.obj'), 'w') as f:
|
||||
with open(os.path.join(dir, "sample_" + str(i) + ".obj"), "w") as f:
|
||||
f.write(out)
|
||||
f.close()
|
||||
|
||||
def export_to_obj_single(path, data, transform=lambda v,f:(v,f)):
|
||||
'''
|
||||
|
||||
def export_to_obj_single(path, data, transform=lambda v, f: (v, f)):
|
||||
"""
|
||||
transform: f(vertices, faces) --> transformed (vertices, faces)
|
||||
'''
|
||||
"""
|
||||
v, f = transform(data[0], data[1])
|
||||
if len(data) > 2:
|
||||
v_color = data[2]
|
||||
|
@ -51,15 +53,15 @@ def export_to_obj_single(path, data, transform=lambda v,f:(v,f)):
|
|||
v_color = None
|
||||
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
||||
out = trimesh.exchange.obj.export_obj(mesh)
|
||||
with open(path, 'w') as f:
|
||||
with open(path, "w") as f:
|
||||
f.write(out)
|
||||
f.close()
|
||||
|
||||
|
||||
def meshwrite(filename, verts, faces, norms, colors):
|
||||
"""Save a 3D mesh to a polygon .ply file.
|
||||
"""
|
||||
"""Save a 3D mesh to a polygon .ply file."""
|
||||
# Write header
|
||||
ply_file = open(filename, 'w')
|
||||
ply_file = open(filename, "w")
|
||||
ply_file.write("ply\n")
|
||||
ply_file.write("format ascii 1.0\n")
|
||||
ply_file.write("element vertex %d\n" % (verts.shape[0]))
|
||||
|
@ -78,11 +80,20 @@ def meshwrite(filename, verts, faces, norms, colors):
|
|||
|
||||
# Write vertex list
|
||||
for i in range(verts.shape[0]):
|
||||
ply_file.write("%f %f %f %f %f %f %d %d %d\n" % (
|
||||
verts[i, 0], verts[i, 1], verts[i, 2],
|
||||
norms[i, 0], norms[i, 1], norms[i, 2],
|
||||
colors[i, 0], colors[i, 1], colors[i, 2],
|
||||
))
|
||||
ply_file.write(
|
||||
"%f %f %f %f %f %f %d %d %d\n"
|
||||
% (
|
||||
verts[i, 0],
|
||||
verts[i, 1],
|
||||
verts[i, 2],
|
||||
norms[i, 0],
|
||||
norms[i, 1],
|
||||
norms[i, 2],
|
||||
colors[i, 0],
|
||||
colors[i, 1],
|
||||
colors[i, 2],
|
||||
)
|
||||
)
|
||||
|
||||
# Write face list
|
||||
for i in range(faces.shape[0]):
|
||||
|
@ -92,14 +103,13 @@ def meshwrite(filename, verts, faces, norms, colors):
|
|||
|
||||
|
||||
def pcwrite(filename, xyz, rgb=None):
|
||||
"""Save a point cloud to a polygon .ply file.
|
||||
"""
|
||||
"""Save a point cloud to a polygon .ply file."""
|
||||
if rgb is None:
|
||||
rgb = np.ones_like(xyz) * 128
|
||||
rgb = rgb.astype(np.uint8)
|
||||
|
||||
# Write header
|
||||
ply_file = open(filename, 'w')
|
||||
ply_file = open(filename, "w")
|
||||
ply_file.write("ply\n")
|
||||
ply_file.write("format ascii 1.0\n")
|
||||
ply_file.write("element vertex %d\n" % (xyz.shape[0]))
|
||||
|
@ -113,60 +123,67 @@ def pcwrite(filename, xyz, rgb=None):
|
|||
|
||||
# Write vertex list
|
||||
for i in range(xyz.shape[0]):
|
||||
ply_file.write("%f %f %f %d %d %d\n" % (
|
||||
xyz[i, 0], xyz[i, 1], xyz[i, 2],
|
||||
rgb[i, 0], rgb[i, 1], rgb[i, 2],
|
||||
))
|
||||
ply_file.write(
|
||||
"%f %f %f %d %d %d\n"
|
||||
% (
|
||||
xyz[i, 0],
|
||||
xyz[i, 1],
|
||||
xyz[i, 2],
|
||||
rgb[i, 0],
|
||||
rgb[i, 1],
|
||||
rgb[i, 2],
|
||||
)
|
||||
)
|
||||
|
||||
'''
|
||||
|
||||
"""
|
||||
Matplotlib Visualization
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
|
||||
r''' Visualizes voxel data.
|
||||
r"""Visualizes voxel data.
|
||||
show only first num_shown
|
||||
'''
|
||||
batch_size =voxels.shape[0]
|
||||
"""
|
||||
batch_size = voxels.shape[0]
|
||||
voxels = voxels.squeeze(1) > threshold
|
||||
|
||||
num_shown = min(num_shown, batch_size)
|
||||
|
||||
n = int(np.sqrt(num_shown))
|
||||
fig = plt.figure(figsize=(20,20))
|
||||
fig = plt.figure(figsize=(20, 20))
|
||||
|
||||
for idx, pc in enumerate(voxels[:num_shown]):
|
||||
if idx >= n*n:
|
||||
if idx >= n * n:
|
||||
break
|
||||
pc = voxels[idx]
|
||||
ax = fig.add_subplot(n, n, idx + 1, projection='3d')
|
||||
ax.voxels(pc, edgecolor='k', facecolors='green', linewidth=0.1, alpha=0.5)
|
||||
ax = fig.add_subplot(n, n, idx + 1, projection="3d")
|
||||
ax.voxels(pc, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5)
|
||||
ax.view_init()
|
||||
ax.axis('off')
|
||||
plt.savefig(out_file, bbox_inches='tight')
|
||||
ax.axis("off")
|
||||
plt.savefig(out_file, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
def visualize_pointcloud(points, normals=None,
|
||||
out_file=None, show=False, elev=30, azim=225):
|
||||
r''' Visualizes point cloud data.
|
||||
|
||||
def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
|
||||
r"""Visualizes point cloud data.
|
||||
Args:
|
||||
points (tensor): point data
|
||||
normals (tensor): normal data (if existing)
|
||||
out_file (string): output file
|
||||
show (bool): whether the plot should be shown
|
||||
'''
|
||||
"""
|
||||
# Create plot
|
||||
fig = plt.figure()
|
||||
ax = fig.gca(projection=Axes3D.name)
|
||||
ax.scatter(points[:, 2], points[:, 0], points[:, 1])
|
||||
if normals is not None:
|
||||
ax.quiver(
|
||||
points[:, 2], points[:, 0], points[:, 1],
|
||||
normals[:, 2], normals[:, 0], normals[:, 1],
|
||||
length=0.1, color='k'
|
||||
points[:, 2], points[:, 0], points[:, 1], normals[:, 2], normals[:, 0], normals[:, 1], length=0.1, color="k"
|
||||
)
|
||||
ax.set_xlabel('Z')
|
||||
ax.set_ylabel('X')
|
||||
ax.set_zlabel('Y')
|
||||
ax.set_xlabel("Z")
|
||||
ax.set_ylabel("X")
|
||||
ax.set_zlabel("Y")
|
||||
# ax.set_xlim(-0.5, 0.5)
|
||||
# ax.set_ylim(-0.5, 0.5)
|
||||
# ax.set_zlim(-0.5, 0.5)
|
||||
|
@ -178,37 +195,39 @@ def visualize_pointcloud(points, normals=None,
|
|||
plt.close(fig)
|
||||
|
||||
|
||||
def visualize_pointcloud_batch(path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225):
|
||||
def visualize_pointcloud_batch(
|
||||
path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225
|
||||
):
|
||||
batch_size = len(pointclouds)
|
||||
fig = plt.figure(figsize=(20,20))
|
||||
fig = plt.figure(figsize=(20, 20))
|
||||
|
||||
ncols = int(np.sqrt(batch_size))
|
||||
nrows = max(1, (batch_size-1) // ncols+1)
|
||||
nrows = max(1, (batch_size - 1) // ncols + 1)
|
||||
for idx, pc in enumerate(pointclouds):
|
||||
if vis_label:
|
||||
label = categories[labels[idx].item()]
|
||||
pred = categories[pred_labels[idx]]
|
||||
colour = 'g' if label == pred else 'r'
|
||||
colour = "g" if label == pred else "r"
|
||||
elif target is None:
|
||||
|
||||
colour = 'g'
|
||||
colour = "g"
|
||||
else:
|
||||
colour = target[idx]
|
||||
pc = pc.cpu().numpy()
|
||||
ax = fig.add_subplot(nrows, ncols, idx + 1, projection='3d')
|
||||
ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
|
||||
ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
|
||||
ax.view_init(elev=elev, azim=azim)
|
||||
ax.axis('off')
|
||||
ax.axis("off")
|
||||
if vis_label:
|
||||
ax.set_title('GT: {0}\nPred: {1}'.format(label, pred))
|
||||
ax.set_title("GT: {0}\nPred: {1}".format(label, pred))
|
||||
|
||||
plt.savefig(path)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
'''
|
||||
"""
|
||||
Plot stats
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def plot_stats(output_dir, stats, interval):
|
||||
content = stats.keys()
|
||||
|
@ -218,5 +237,5 @@ def plot_stats(output_dir, stats, interval):
|
|||
axs[j].plot(interval, v)
|
||||
axs[j].set_ylabel(k)
|
||||
|
||||
f.savefig(os.path.join(output_dir, 'stat.pdf'), bbox_inches='tight')
|
||||
f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
|
||||
plt.close(f)
|
||||
|
|
Loading…
Reference in a new issue