style: autoformatting
This commit is contained in:
parent
d887d74852
commit
2fbfc320f2
|
@ -1,32 +1,33 @@
|
||||||
|
|
||||||
from glob import glob
|
|
||||||
import re
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
from glob import glob
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def raw_camparam_from_xml(path, pose="lookAt"):
|
def raw_camparam_from_xml(path, pose="lookAt"):
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
tree = ET.parse(path)
|
tree = ET.parse(path)
|
||||||
elm = tree.find("./sensor/transform/" + pose)
|
elm = tree.find("./sensor/transform/" + pose)
|
||||||
camparam = elm.attrib
|
camparam = elm.attrib
|
||||||
origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
|
origin = np.fromstring(camparam["origin"], dtype=np.float32, sep=",")
|
||||||
target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
|
target = np.fromstring(camparam["target"], dtype=np.float32, sep=",")
|
||||||
up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
|
up = np.fromstring(camparam["up"], dtype=np.float32, sep=",")
|
||||||
height = int(
|
height = int(tree.find("./sensor/film/integer[@name='height']").attrib["value"])
|
||||||
tree.find("./sensor/film/integer[@name='height']").attrib['value'])
|
width = int(tree.find("./sensor/film/integer[@name='width']").attrib["value"])
|
||||||
width = int(
|
|
||||||
tree.find("./sensor/film/integer[@name='width']").attrib['value'])
|
|
||||||
|
|
||||||
camparam = dict()
|
camparam = dict()
|
||||||
camparam['origin'] = origin
|
camparam["origin"] = origin
|
||||||
camparam['up'] = up
|
camparam["up"] = up
|
||||||
camparam['target'] = target
|
camparam["target"] = target
|
||||||
camparam['height'] = height
|
camparam["height"] = height
|
||||||
camparam['width'] = width
|
camparam["width"] = width
|
||||||
return camparam
|
return camparam
|
||||||
|
|
||||||
|
|
||||||
def get_cam_pos(origin, target, up):
|
def get_cam_pos(origin, target, up):
|
||||||
inward = origin - target
|
inward = origin - target
|
||||||
right = np.cross(up, inward)
|
right = np.cross(up, inward)
|
||||||
|
@ -38,59 +39,54 @@ def get_cam_pos(origin, target, up):
|
||||||
ry /= np.linalg.norm(ry)
|
ry /= np.linalg.norm(ry)
|
||||||
rz /= np.linalg.norm(rz)
|
rz /= np.linalg.norm(rz)
|
||||||
|
|
||||||
rot = np.stack([
|
rot = np.stack([rx, ry, -rz], axis=0)
|
||||||
rx,
|
|
||||||
ry,
|
|
||||||
-rz
|
|
||||||
], axis=0)
|
|
||||||
|
|
||||||
|
|
||||||
aff = np.concatenate([
|
|
||||||
np.eye(3), -origin[:,None]
|
|
||||||
], axis=1)
|
|
||||||
|
|
||||||
|
aff = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
|
||||||
|
|
||||||
ext = np.matmul(rot, aff)
|
ext = np.matmul(rot, aff)
|
||||||
|
|
||||||
result = np.concatenate(
|
result = np.concatenate([ext, np.array([[0, 0, 0, 1]])], axis=0)
|
||||||
[ext, np.array([[0,0,0,1]])], axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
|
def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
|
||||||
depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
|
depths = sorted(glob(os.path.join(datapoint_dir, "*depth.png")))
|
||||||
cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths]
|
cam_ext = [
|
||||||
|
"_".join(re.sub(dataroot.strip("/"), camera_param_dir.strip("/"), f).split("_")[:-1]) + ".xml" for f in depths
|
||||||
|
]
|
||||||
|
|
||||||
for i, (f, pth) in enumerate(zip(cam_ext, depths)):
|
for i, (f, pth) in enumerate(zip(cam_ext, depths)):
|
||||||
if not os.path.exists(f):
|
if not os.path.exists(f):
|
||||||
continue
|
continue
|
||||||
params=raw_camparam_from_xml(f)
|
params = raw_camparam_from_xml(f)
|
||||||
origin, target, up, width, height = params['origin'], params['target'], params['up'],\
|
origin, target, up, width, height = (
|
||||||
params['width'], params['height']
|
params["origin"],
|
||||||
|
params["target"],
|
||||||
|
params["up"],
|
||||||
|
params["width"],
|
||||||
|
params["height"],
|
||||||
|
)
|
||||||
|
|
||||||
ext_matrix = get_cam_pos(origin, target, up)
|
ext_matrix = get_cam_pos(origin, target, up)
|
||||||
|
|
||||||
#####
|
#####
|
||||||
diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
|
diag = (0.036**2 + 0.024**2) ** 0.5
|
||||||
focal_length = 0.05
|
focal_length = 0.05
|
||||||
res = [480, 480]
|
res = [480, 480]
|
||||||
h_relative = (res[1] / res[0])
|
h_relative = res[1] / res[0]
|
||||||
sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
|
sensor_width = np.sqrt(diag**2 / (1 + h_relative**2))
|
||||||
pix_size = sensor_width / res[0]
|
pix_size = sensor_width / res[0]
|
||||||
|
|
||||||
K = np.array([
|
K = np.array(
|
||||||
[focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
|
[
|
||||||
[0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
|
[focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
|
||||||
[0, 0, 1]
|
[0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
|
||||||
])
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
np.savez(pth.split('depth.png')[0]+ 'cam_params.npz', extr=ext_matrix, intr=K)
|
np.savez(pth.split("depth.png")[0] + "cam_params.npz", extr=ext_matrix, intr=K)
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
|
@ -102,21 +98,16 @@ def main(opt):
|
||||||
if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
|
if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
|
||||||
leaf_subdirs.append(dirpath)
|
leaf_subdirs.append(dirpath)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for k, dir_ in enumerate(leaf_subdirs):
|
for k, dir_ in enumerate(leaf_subdirs):
|
||||||
print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
|
print("Processing dir {}/{}: {}".format(k, len(leaf_subdirs), dir_))
|
||||||
|
|
||||||
convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
|
convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = argparse.ArgumentParser()
|
args = argparse.ArgumentParser()
|
||||||
args.add_argument('--dataroot', type=str, default='GenReData/')
|
args.add_argument("--dataroot", type=str, default="GenReData/")
|
||||||
args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
|
args.add_argument("--mitsuba_xml_root", type=str, default="GenReData/genre-xml_v2")
|
||||||
|
|
||||||
opt = args.parse_args()
|
opt = args.parse_args()
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
import trimesh
|
import trimesh
|
||||||
from plyfile import PlyData, PlyElement
|
from plyfile import PlyData, PlyElement
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
def project_pc_to_image(points, resolution=64):
|
def project_pc_to_image(points, resolution=64):
|
||||||
"""project point clouds into 2D image
|
"""project point clouds into 2D image
|
||||||
|
@ -26,29 +28,32 @@ def project_pc_to_image(points, resolution=64):
|
||||||
|
|
||||||
|
|
||||||
def write_ply(points, filename, text=False):
|
def write_ply(points, filename, text=False):
|
||||||
""" input: Nx3, write points to filename as PLY format. """
|
"""input: Nx3, write points to filename as PLY format."""
|
||||||
points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
|
points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
|
||||||
vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
|
vertex = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
|
||||||
el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
|
el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
|
||||||
with open(filename, mode='wb') as f:
|
with open(filename, mode="wb") as f:
|
||||||
PlyData([el], text=text).write(f)
|
PlyData([el], text=text).write(f)
|
||||||
|
|
||||||
|
|
||||||
def rotate_point_cloud(points, transformation_mat):
|
def rotate_point_cloud(points, transformation_mat):
|
||||||
|
|
||||||
new_points = np.dot(transformation_mat, points.T).T
|
new_points = np.dot(transformation_mat, points.T).T
|
||||||
|
|
||||||
return new_points
|
return new_points
|
||||||
|
|
||||||
|
|
||||||
def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
|
def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
|
||||||
""" align 3depn shapes to shapenet coordinates"""
|
"""align 3depn shapes to shapenet coordinates"""
|
||||||
# angle = math.radians(angle_deg)
|
# angle = math.radians(angle_deg)
|
||||||
# rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
|
# rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
|
||||||
# rot_m = rot_m.to_matrix()
|
# rot_m = rot_m.to_matrix()
|
||||||
rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
|
rot_m = np.array(
|
||||||
[ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
|
[
|
||||||
[-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
|
[2.22044605e-16, 0.00000000e00, 1.00000000e00],
|
||||||
|
[0.00000000e00, 1.00000000e00, 0.00000000e00],
|
||||||
|
[-1.00000000e00, 0.00000000e00, 2.22044605e-16],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
new_points = rotate_point_cloud(points, rot_m)
|
new_points = rotate_point_cloud(points, rot_m)
|
||||||
|
|
||||||
|
@ -87,14 +92,13 @@ def sample_point_cloud_by_n(points, n_pts):
|
||||||
return points
|
return points
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def collect_data_id(split_dir, classname, phase):
|
def collect_data_id(split_dir, classname, phase):
|
||||||
filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
|
filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
|
||||||
if not os.path.exists(filename):
|
if not os.path.exists(filename):
|
||||||
raise ValueError("Invalid filepath: {}".format(filename))
|
raise ValueError("Invalid filepath: {}".format(filename))
|
||||||
|
|
||||||
all_ids = []
|
all_ids = []
|
||||||
with open(filename, 'r') as fp:
|
with open(filename, "r") as fp:
|
||||||
info = json.load(fp)
|
info = json.load(fp)
|
||||||
for item in info:
|
for item in info:
|
||||||
all_ids.append(item["anno_id"])
|
all_ids.append(item["anno_id"])
|
||||||
|
@ -102,7 +106,6 @@ def collect_data_id(split_dir, classname, phase):
|
||||||
return all_ids
|
return all_ids
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GANdatasetPartNet(Dataset):
|
class GANdatasetPartNet(Dataset):
|
||||||
def __init__(self, phase, data_root, category, n_pts):
|
def __init__(self, phase, data_root, category, n_pts):
|
||||||
super(GANdatasetPartNet, self).__init__()
|
super(GANdatasetPartNet, self).__init__()
|
||||||
|
@ -114,10 +117,12 @@ class GANdatasetPartNet(Dataset):
|
||||||
|
|
||||||
self.data_root = data_root
|
self.data_root = data_root
|
||||||
|
|
||||||
shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase)
|
shape_names = collect_data_id(
|
||||||
|
os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
|
||||||
|
)
|
||||||
self.shape_names = []
|
self.shape_names = []
|
||||||
for name in shape_names:
|
for name in shape_names:
|
||||||
path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
|
path = os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
self.shape_names.append(name)
|
self.shape_names.append(name)
|
||||||
|
|
||||||
|
@ -129,12 +134,12 @@ class GANdatasetPartNet(Dataset):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_point_cloud(path):
|
def load_point_cloud(path):
|
||||||
pc = trimesh.load(path)
|
pc = trimesh.load(path)
|
||||||
pc = pc.vertices / 2.0 # scale to unit sphere
|
pc = pc.vertices / 2.0 # scale to unit sphere
|
||||||
return pc
|
return pc
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def read_point_cloud_part_label(path):
|
def read_point_cloud_part_label(path):
|
||||||
with open(path, 'r') as fp:
|
with open(path, "r") as fp:
|
||||||
labels = fp.readlines()
|
labels = fp.readlines()
|
||||||
labels = np.array([int(x) for x in labels])
|
labels = np.array([int(x) for x in labels])
|
||||||
return labels
|
return labels
|
||||||
|
@ -156,26 +161,31 @@ class GANdatasetPartNet(Dataset):
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
raw_shape_name = self.shape_names[index]
|
raw_shape_name = self.shape_names[index]
|
||||||
raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
|
raw_ply_path = os.path.join(self.data_root, "partnet_data", raw_shape_name, "point_sample/ply-10000.ply")
|
||||||
raw_pc = self.load_point_cloud(raw_ply_path)
|
raw_pc = self.load_point_cloud(raw_ply_path)
|
||||||
|
|
||||||
raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
|
raw_label_path = os.path.join(
|
||||||
|
self.data_root, "partnet_labels/partnet_pc_label", raw_shape_name, "label-merge-level1-10000.txt"
|
||||||
|
)
|
||||||
part_labels = self.read_point_cloud_part_label(raw_label_path)
|
part_labels = self.read_point_cloud_part_label(raw_label_path)
|
||||||
raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
|
raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
|
||||||
raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
|
raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
|
||||||
raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
|
raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
|
||||||
|
|
||||||
real_shape_name = self.shape_names[index]
|
real_shape_name = self.shape_names[index]
|
||||||
real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
|
real_ply_path = os.path.join(self.data_root, "partnet_data", real_shape_name, "point_sample/ply-10000.ply")
|
||||||
real_pc = self.load_point_cloud(real_ply_path)
|
real_pc = self.load_point_cloud(real_ply_path)
|
||||||
real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
|
real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
|
||||||
real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
|
real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
|
||||||
|
|
||||||
return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
|
return {
|
||||||
'n_part_keep': n_part_keep, 'idx': index}
|
"raw": raw_pc,
|
||||||
|
"real": real_pc,
|
||||||
|
"raw_id": raw_shape_name,
|
||||||
|
"real_id": real_shape_name,
|
||||||
|
"n_part_keep": n_part_keep,
|
||||||
|
"idx": index,
|
||||||
|
}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.shape_names)
|
return len(self.shape_names)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,33 +1,67 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from torch.utils import data
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn.functional as F
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
|
# taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
|
||||||
synsetid_to_cate = {
|
synsetid_to_cate = {
|
||||||
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
|
"02691156": "airplane",
|
||||||
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
|
"02773838": "bag",
|
||||||
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
|
"02801938": "basket",
|
||||||
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
|
"02808440": "bathtub",
|
||||||
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
|
"02818832": "bed",
|
||||||
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
|
"02828884": "bench",
|
||||||
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
|
"02876657": "bottle",
|
||||||
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
|
"02880940": "bowl",
|
||||||
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
|
"02924116": "bus",
|
||||||
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
|
"02933112": "cabinet",
|
||||||
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
|
"02747177": "can",
|
||||||
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
|
"02942699": "camera",
|
||||||
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
|
"02954340": "cap",
|
||||||
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
|
"02958343": "car",
|
||||||
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
|
"03001627": "chair",
|
||||||
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
|
"03046257": "clock",
|
||||||
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
|
"03207941": "dishwasher",
|
||||||
'04554684': 'washer', '02992529': 'cellphone',
|
"03211117": "monitor",
|
||||||
'02843684': 'birdhouse', '02871439': 'bookshelf',
|
"04379243": "table",
|
||||||
|
"04401088": "telephone",
|
||||||
|
"02946921": "tin_can",
|
||||||
|
"04460130": "tower",
|
||||||
|
"04468005": "train",
|
||||||
|
"03085013": "keyboard",
|
||||||
|
"03261776": "earphone",
|
||||||
|
"03325088": "faucet",
|
||||||
|
"03337140": "file",
|
||||||
|
"03467517": "guitar",
|
||||||
|
"03513137": "helmet",
|
||||||
|
"03593526": "jar",
|
||||||
|
"03624134": "knife",
|
||||||
|
"03636649": "lamp",
|
||||||
|
"03642806": "laptop",
|
||||||
|
"03691459": "speaker",
|
||||||
|
"03710193": "mailbox",
|
||||||
|
"03759954": "microphone",
|
||||||
|
"03761084": "microwave",
|
||||||
|
"03790512": "motorcycle",
|
||||||
|
"03797390": "mug",
|
||||||
|
"03928116": "piano",
|
||||||
|
"03938244": "pillow",
|
||||||
|
"03948459": "pistol",
|
||||||
|
"03991062": "pot",
|
||||||
|
"04004475": "printer",
|
||||||
|
"04074963": "remote_control",
|
||||||
|
"04090263": "rifle",
|
||||||
|
"04099429": "rocket",
|
||||||
|
"04225987": "skateboard",
|
||||||
|
"04256520": "sofa",
|
||||||
|
"04330267": "stove",
|
||||||
|
"04530566": "vessel",
|
||||||
|
"04554684": "washer",
|
||||||
|
"02992529": "cellphone",
|
||||||
|
"02843684": "birdhouse",
|
||||||
|
"02871439": "bookshelf",
|
||||||
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
||||||
# '02834778': 'bicycle', not in our taxonomy
|
# '02834778': 'bicycle', not in our taxonomy
|
||||||
}
|
}
|
||||||
|
@ -35,13 +69,23 @@ cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
|
||||||
|
|
||||||
|
|
||||||
class Uniform15KPC(Dataset):
|
class Uniform15KPC(Dataset):
|
||||||
def __init__(self, root_dir, subdirs, tr_sample_size=10000,
|
def __init__(
|
||||||
te_sample_size=10000, split='train', scale=1.,
|
self,
|
||||||
normalize_per_shape=False, box_per_shape=False,
|
root_dir,
|
||||||
random_subsample=False,
|
subdirs,
|
||||||
normalize_std_per_axis=False,
|
tr_sample_size=10000,
|
||||||
all_points_mean=None, all_points_std=None,
|
te_sample_size=10000,
|
||||||
input_dim=3, use_mask=False):
|
split="train",
|
||||||
|
scale=1.0,
|
||||||
|
normalize_per_shape=False,
|
||||||
|
box_per_shape=False,
|
||||||
|
random_subsample=False,
|
||||||
|
normalize_std_per_axis=False,
|
||||||
|
all_points_mean=None,
|
||||||
|
all_points_std=None,
|
||||||
|
input_dim=3,
|
||||||
|
use_mask=False,
|
||||||
|
):
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.split = split
|
self.split = split
|
||||||
self.in_tr_sample_size = tr_sample_size
|
self.in_tr_sample_size = tr_sample_size
|
||||||
|
@ -67,9 +111,9 @@ class Uniform15KPC(Dataset):
|
||||||
|
|
||||||
all_mids = []
|
all_mids = []
|
||||||
for x in os.listdir(sub_path):
|
for x in os.listdir(sub_path):
|
||||||
if not x.endswith('.npy'):
|
if not x.endswith(".npy"):
|
||||||
continue
|
continue
|
||||||
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
|
all_mids.append(os.path.join(self.split, x[: -len(".npy")]))
|
||||||
|
|
||||||
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
|
# NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
|
||||||
for mid in all_mids:
|
for mid in all_mids:
|
||||||
|
@ -111,7 +155,9 @@ class Uniform15KPC(Dataset):
|
||||||
B, N = self.all_points.shape[:2]
|
B, N = self.all_points.shape[:2]
|
||||||
self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
|
self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
|
||||||
|
|
||||||
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
|
self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(
|
||||||
|
axis=1
|
||||||
|
).reshape(B, 1, input_dim)
|
||||||
|
|
||||||
else: # normalize across the dataset
|
else: # normalize across the dataset
|
||||||
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
|
self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
|
||||||
|
@ -129,8 +175,7 @@ class Uniform15KPC(Dataset):
|
||||||
self.tr_sample_size = min(10000, tr_sample_size)
|
self.tr_sample_size = min(10000, tr_sample_size)
|
||||||
self.te_sample_size = min(5000, te_sample_size)
|
self.te_sample_size = min(5000, te_sample_size)
|
||||||
print("Total number of data:%d" % len(self.train_points))
|
print("Total number of data:%d" % len(self.train_points))
|
||||||
print("Min number of points: (train)%d (test)%d"
|
print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
|
||||||
% (self.tr_sample_size, self.te_sample_size))
|
|
||||||
assert self.scale == 1, "Scale (!= 1) is deprecated"
|
assert self.scale == 1, "Scale (!= 1) is deprecated"
|
||||||
|
|
||||||
def get_pc_stats(self, idx):
|
def get_pc_stats(self, idx):
|
||||||
|
@ -139,7 +184,6 @@ class Uniform15KPC(Dataset):
|
||||||
s = self.all_points_std[idx].reshape(1, -1)
|
s = self.all_points_std[idx].reshape(1, -1)
|
||||||
return m, s
|
return m, s
|
||||||
|
|
||||||
|
|
||||||
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
|
return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
|
||||||
|
|
||||||
def renormalize(self, mean, std):
|
def renormalize(self, mean, std):
|
||||||
|
@ -173,11 +217,14 @@ class Uniform15KPC(Dataset):
|
||||||
sid, mid = self.all_cate_mids[idx]
|
sid, mid = self.all_cate_mids[idx]
|
||||||
|
|
||||||
out = {
|
out = {
|
||||||
'idx': idx,
|
"idx": idx,
|
||||||
'train_points': tr_out,
|
"train_points": tr_out,
|
||||||
'test_points': te_out,
|
"test_points": te_out,
|
||||||
'mean': m, 'std': s, 'cate_idx': cate_idx,
|
"mean": m,
|
||||||
'sid': sid, 'mid': mid
|
"std": s,
|
||||||
|
"cate_idx": cate_idx,
|
||||||
|
"sid": sid,
|
||||||
|
"mid": mid,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.use_mask:
|
if self.use_mask:
|
||||||
|
@ -192,26 +239,35 @@ class Uniform15KPC(Dataset):
|
||||||
# out['train_points_masked'] = masked
|
# out['train_points_masked'] = masked
|
||||||
# out['train_masks'] = tr_mask
|
# out['train_masks'] = tr_mask
|
||||||
tr_mask = self.mask_transform(tr_out)
|
tr_mask = self.mask_transform(tr_out)
|
||||||
out['train_masks'] = tr_mask
|
out["train_masks"] = tr_mask
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ShapeNet15kPointClouds(Uniform15KPC):
|
class ShapeNet15kPointClouds(Uniform15KPC):
|
||||||
def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
|
def __init__(
|
||||||
categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
|
self,
|
||||||
split='train', scale=1., normalize_per_shape=False,
|
root_dir="data/ShapeNetCore.v2.PC15k",
|
||||||
normalize_std_per_axis=False, box_per_shape=False,
|
categories=["airplane"],
|
||||||
random_subsample=False,
|
tr_sample_size=10000,
|
||||||
all_points_mean=None, all_points_std=None,
|
te_sample_size=2048,
|
||||||
use_mask=False):
|
split="train",
|
||||||
|
scale=1.0,
|
||||||
|
normalize_per_shape=False,
|
||||||
|
normalize_std_per_axis=False,
|
||||||
|
box_per_shape=False,
|
||||||
|
random_subsample=False,
|
||||||
|
all_points_mean=None,
|
||||||
|
all_points_std=None,
|
||||||
|
use_mask=False,
|
||||||
|
):
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.split = split
|
self.split = split
|
||||||
assert self.split in ['train', 'test', 'val']
|
assert self.split in ["train", "test", "val"]
|
||||||
self.tr_sample_size = tr_sample_size
|
self.tr_sample_size = tr_sample_size
|
||||||
self.te_sample_size = te_sample_size
|
self.te_sample_size = te_sample_size
|
||||||
self.cates = categories
|
self.cates = categories
|
||||||
if 'all' in categories:
|
if "all" in categories:
|
||||||
self.synset_ids = list(cate_to_synsetid.values())
|
self.synset_ids = list(cate_to_synsetid.values())
|
||||||
else:
|
else:
|
||||||
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
|
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
|
||||||
|
@ -221,19 +277,21 @@ class ShapeNet15kPointClouds(Uniform15KPC):
|
||||||
self.display_axis_order = [0, 2, 1]
|
self.display_axis_order = [0, 2, 1]
|
||||||
|
|
||||||
super(ShapeNet15kPointClouds, self).__init__(
|
super(ShapeNet15kPointClouds, self).__init__(
|
||||||
root_dir, self.synset_ids,
|
root_dir,
|
||||||
|
self.synset_ids,
|
||||||
tr_sample_size=tr_sample_size,
|
tr_sample_size=tr_sample_size,
|
||||||
te_sample_size=te_sample_size,
|
te_sample_size=te_sample_size,
|
||||||
split=split, scale=scale,
|
split=split,
|
||||||
normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
|
scale=scale,
|
||||||
|
normalize_per_shape=normalize_per_shape,
|
||||||
|
box_per_shape=box_per_shape,
|
||||||
normalize_std_per_axis=normalize_std_per_axis,
|
normalize_std_per_axis=normalize_std_per_axis,
|
||||||
random_subsample=random_subsample,
|
random_subsample=random_subsample,
|
||||||
all_points_mean=all_points_mean, all_points_std=all_points_std,
|
all_points_mean=all_points_mean,
|
||||||
input_dim=3, use_mask=use_mask)
|
all_points_std=all_points_std,
|
||||||
|
input_dim=3,
|
||||||
|
use_mask=use_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
####################################################################################
|
####################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,34 +1,70 @@
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pathlib import Path
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import torch
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
synset_to_label = {
|
synset_to_label = {
|
||||||
'02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
|
"02691156": "airplane",
|
||||||
'02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
|
"02773838": "bag",
|
||||||
'02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
|
"02801938": "basket",
|
||||||
'02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
|
"02808440": "bathtub",
|
||||||
'02954340': 'cap', '02958343': 'car', '03001627': 'chair',
|
"02818832": "bed",
|
||||||
'03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
|
"02828884": "bench",
|
||||||
'04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
|
"02876657": "bottle",
|
||||||
'04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
|
"02880940": "bowl",
|
||||||
'03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
|
"02924116": "bus",
|
||||||
'03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
|
"02933112": "cabinet",
|
||||||
'03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
|
"02747177": "can",
|
||||||
'03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
|
"02942699": "camera",
|
||||||
'03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
|
"02954340": "cap",
|
||||||
'03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
|
"02958343": "car",
|
||||||
'03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
|
"03001627": "chair",
|
||||||
'04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
|
"03046257": "clock",
|
||||||
'04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
|
"03207941": "dishwasher",
|
||||||
'04554684': 'washer', '02992529': 'cellphone',
|
"03211117": "monitor",
|
||||||
'02843684': 'birdhouse', '02871439': 'bookshelf',
|
"04379243": "table",
|
||||||
|
"04401088": "telephone",
|
||||||
|
"02946921": "tin_can",
|
||||||
|
"04460130": "tower",
|
||||||
|
"04468005": "train",
|
||||||
|
"03085013": "keyboard",
|
||||||
|
"03261776": "earphone",
|
||||||
|
"03325088": "faucet",
|
||||||
|
"03337140": "file",
|
||||||
|
"03467517": "guitar",
|
||||||
|
"03513137": "helmet",
|
||||||
|
"03593526": "jar",
|
||||||
|
"03624134": "knife",
|
||||||
|
"03636649": "lamp",
|
||||||
|
"03642806": "laptop",
|
||||||
|
"03691459": "speaker",
|
||||||
|
"03710193": "mailbox",
|
||||||
|
"03759954": "microphone",
|
||||||
|
"03761084": "microwave",
|
||||||
|
"03790512": "motorcycle",
|
||||||
|
"03797390": "mug",
|
||||||
|
"03928116": "piano",
|
||||||
|
"03938244": "pillow",
|
||||||
|
"03948459": "pistol",
|
||||||
|
"03991062": "pot",
|
||||||
|
"04004475": "printer",
|
||||||
|
"04074963": "remote_control",
|
||||||
|
"04090263": "rifle",
|
||||||
|
"04099429": "rocket",
|
||||||
|
"04225987": "skateboard",
|
||||||
|
"04256520": "sofa",
|
||||||
|
"04330267": "stove",
|
||||||
|
"04530566": "vessel",
|
||||||
|
"04554684": "washer",
|
||||||
|
"02992529": "cellphone",
|
||||||
|
"02843684": "birdhouse",
|
||||||
|
"02871439": "bookshelf",
|
||||||
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
||||||
# '02834778': 'bicycle', not in our taxonomy
|
# '02834778': 'bicycle', not in our taxonomy
|
||||||
}
|
}
|
||||||
|
@ -36,30 +72,44 @@ synset_to_label = {
|
||||||
# Label to Synset mapping (for ShapeNet core classes)
|
# Label to Synset mapping (for ShapeNet core classes)
|
||||||
label_to_synset = {v: k for k, v in synset_to_label.items()}
|
label_to_synset = {v: k for k, v in synset_to_label.items()}
|
||||||
|
|
||||||
|
|
||||||
def _convert_categories(categories):
|
def _convert_categories(categories):
|
||||||
assert categories is not None, 'List of categories cannot be empty!'
|
assert categories is not None, "List of categories cannot be empty!"
|
||||||
if not (c in synset_to_label.keys() + label_to_synset.keys()
|
if not (c in synset_to_label.keys() + label_to_synset.keys() for c in categories):
|
||||||
for c in categories):
|
warnings.warn(
|
||||||
warnings.warn('Some or all of the categories requested are not part of \
|
"Some or all of the categories requested are not part of \
|
||||||
ShapeNetCore. Data loading may fail if these categories are not avaliable.')
|
ShapeNetCore. Data loading may fail if these categories are not avaliable."
|
||||||
synsets = [label_to_synset[c] if c in label_to_synset.keys()
|
)
|
||||||
else c for c in categories]
|
synsets = [label_to_synset[c] if c in label_to_synset.keys() else c for c in categories]
|
||||||
return synsets
|
return synsets
|
||||||
|
|
||||||
|
|
||||||
class ShapeNet_Multiview_Points(Dataset):
|
class ShapeNet_Multiview_Points(Dataset):
|
||||||
def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val',
|
def __init__(
|
||||||
npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
|
self,
|
||||||
|
root_pc: str,
|
||||||
|
root_views: str,
|
||||||
|
cache: str,
|
||||||
|
categories: list = ["chair"],
|
||||||
|
split: str = "val",
|
||||||
|
npoints=2048,
|
||||||
|
sv_samples=800,
|
||||||
|
all_points_mean=None,
|
||||||
|
all_points_std=None,
|
||||||
|
get_image=False,
|
||||||
|
):
|
||||||
self.root = Path(root_views)
|
self.root = Path(root_views)
|
||||||
self.split = split
|
self.split = split
|
||||||
self.get_image = get_image
|
self.get_image = get_image
|
||||||
params = {
|
params = {
|
||||||
'cat': categories,
|
"cat": categories,
|
||||||
'npoints': npoints,
|
"npoints": npoints,
|
||||||
'sv_samples': sv_samples,
|
"sv_samples": sv_samples,
|
||||||
}
|
}
|
||||||
params = tuple(sorted(pair for pair in params.items()))
|
params = tuple(sorted(pair for pair in params.items()))
|
||||||
self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
|
self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
|
||||||
|
"_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
|
||||||
|
)
|
||||||
|
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.paths = []
|
self.paths = []
|
||||||
|
@ -74,13 +124,12 @@ class ShapeNet_Multiview_Points(Dataset):
|
||||||
|
|
||||||
# loops through desired classes
|
# loops through desired classes
|
||||||
for i in range(len(self.synsets)):
|
for i in range(len(self.synsets)):
|
||||||
|
|
||||||
syn = self.synsets[i]
|
syn = self.synsets[i]
|
||||||
class_target = self.root / syn
|
class_target = self.root / syn
|
||||||
if not class_target.exists():
|
if not class_target.exists():
|
||||||
raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
|
raise ValueError(
|
||||||
syn, self.labels[i], str(class_target)))
|
"Class {0} ({1}) was not found at location {2}.".format(syn, self.labels[i], str(class_target))
|
||||||
|
)
|
||||||
|
|
||||||
sub_path_pc = os.path.join(root_pc, syn, split)
|
sub_path_pc = os.path.join(root_pc, syn, split)
|
||||||
if not os.path.isdir(sub_path_pc):
|
if not os.path.isdir(sub_path_pc):
|
||||||
|
@ -90,30 +139,30 @@ class ShapeNet_Multiview_Points(Dataset):
|
||||||
self.all_mids = []
|
self.all_mids = []
|
||||||
self.imgs = []
|
self.imgs = []
|
||||||
for x in os.listdir(sub_path_pc):
|
for x in os.listdir(sub_path_pc):
|
||||||
if not x.endswith('.npy'):
|
if not x.endswith(".npy"):
|
||||||
continue
|
continue
|
||||||
self.all_mids.append(os.path.join(split, x[:-len('.npy')]))
|
self.all_mids.append(os.path.join(split, x[: -len(".npy")]))
|
||||||
|
|
||||||
for mid in tqdm(self.all_mids):
|
for mid in tqdm(self.all_mids):
|
||||||
# obj_fname = os.path.join(sub_path, x)
|
# obj_fname = os.path.join(sub_path, x)
|
||||||
obj_fname = os.path.join(root_pc, syn, mid + ".npy")
|
obj_fname = os.path.join(root_pc, syn, mid + ".npy")
|
||||||
cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz'))
|
cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
|
||||||
if len(cams_pths) < 20:
|
if len(cams_pths) < 20:
|
||||||
continue
|
continue
|
||||||
point_cloud = np.load(obj_fname)
|
point_cloud = np.load(obj_fname)
|
||||||
sv_points_group = []
|
sv_points_group = []
|
||||||
img_path_group = []
|
img_path_group = []
|
||||||
(self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True)
|
(self.cache_dir / (mid.split("/")[-1])).mkdir(parents=True, exist_ok=True)
|
||||||
success = True
|
success = True
|
||||||
for i, cp in enumerate(cams_pths):
|
for i, cp in enumerate(cams_pths):
|
||||||
cp = str(cp)
|
cp = str(cp)
|
||||||
vp = cp.split('cam_params')[0] + 'depth.png'
|
vp = cp.split("cam_params")[0] + "depth.png"
|
||||||
depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
|
depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
|
||||||
cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) )
|
cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))
|
||||||
|
|
||||||
cam_params = np.load(cp)
|
cam_params = np.load(cp)
|
||||||
extr = cam_params['extr']
|
extr = cam_params["extr"]
|
||||||
intr = cam_params['intr']
|
intr = cam_params["intr"]
|
||||||
|
|
||||||
self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
|
self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
|
||||||
|
|
||||||
|
@ -125,7 +174,7 @@ class ShapeNet_Multiview_Points(Dataset):
|
||||||
sv_points_group.append(sv_point_cloud)
|
sv_points_group.append(sv_point_cloud)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
success=False
|
success = False
|
||||||
break
|
break
|
||||||
if not success:
|
if not success:
|
||||||
continue
|
continue
|
||||||
|
@ -144,64 +193,66 @@ class ShapeNet_Multiview_Points(Dataset):
|
||||||
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
|
self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
|
||||||
|
|
||||||
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
|
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
|
||||||
self.train_points = self.all_points[:,:10000]
|
self.train_points = self.all_points[:, :10000]
|
||||||
self.test_points = self.all_points[:,10000:]
|
self.test_points = self.all_points[:, 10000:]
|
||||||
self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std
|
self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std
|
||||||
|
|
||||||
def get_pc_stats(self, idx):
|
def get_pc_stats(self, idx):
|
||||||
|
return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)
|
||||||
return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Returns the length of the dataset. """
|
"""Returns the length of the dataset."""
|
||||||
return len(self.all_points)
|
return len(self.all_points)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
|
||||||
|
|
||||||
tr_out = self.train_points[index]
|
tr_out = self.train_points[index]
|
||||||
tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
|
tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
|
||||||
tr_out = tr_out[tr_idxs, :]
|
tr_out = tr_out[tr_idxs, :]
|
||||||
|
|
||||||
gt_points = self.test_points[index][:self.npoints]
|
gt_points = self.test_points[index][: self.npoints]
|
||||||
|
|
||||||
m, s = self.get_pc_stats(index)
|
m, s = self.get_pc_stats(index)
|
||||||
|
|
||||||
sv_points = self.all_points_sv[index]
|
sv_points = self.all_points_sv[index]
|
||||||
|
|
||||||
idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False)
|
idxs = np.arange(0, sv_points.shape[-2])[
|
||||||
|
: self.sv_samples
|
||||||
|
] # np.random.choice(sv_points.shape[0], 500, replace=False)
|
||||||
|
|
||||||
data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(),
|
data = torch.cat(
|
||||||
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
|
[
|
||||||
|
torch.from_numpy(sv_points[:, idxs]).float(),
|
||||||
|
torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
masks = torch.zeros_like(data)
|
masks = torch.zeros_like(data)
|
||||||
masks[:,:idxs.shape[0]] = 1
|
masks[:, : idxs.shape[0]] = 1
|
||||||
|
|
||||||
res = {'train_points': torch.from_numpy(tr_out).float(),
|
res = {
|
||||||
'test_points': torch.from_numpy(gt_points).float(),
|
"train_points": torch.from_numpy(tr_out).float(),
|
||||||
'sv_points': data,
|
"test_points": torch.from_numpy(gt_points).float(),
|
||||||
'masks': masks,
|
"sv_points": data,
|
||||||
'std': s, 'mean': m,
|
"masks": masks,
|
||||||
'idx': index,
|
"std": s,
|
||||||
'name':self.all_mids[index]
|
"mean": m,
|
||||||
}
|
"idx": index,
|
||||||
|
"name": self.all_mids[index],
|
||||||
if self.split != 'train' and self.get_image:
|
}
|
||||||
|
|
||||||
|
if self.split != "train" and self.get_image:
|
||||||
img_lst = []
|
img_lst = []
|
||||||
for n in range(self.all_points_sv.shape[1]):
|
for n in range(self.all_points_sv.shape[1]):
|
||||||
|
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
|
||||||
img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3]
|
|
||||||
|
|
||||||
img_lst.append(img)
|
img_lst.append(img)
|
||||||
|
|
||||||
img = torch.stack(img_lst, dim=0)
|
img = torch.stack(img_lst, dim=0)
|
||||||
|
|
||||||
res['image'] = img
|
res["image"] = img
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _render(self, cache_path, depth_pth, depth_minmax_pth):
|
def _render(self, cache_path, depth_pth, depth_minmax_pth):
|
||||||
# if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
|
# if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
|
||||||
#
|
#
|
||||||
|
@ -210,11 +261,9 @@ class ShapeNet_Multiview_Points(Dataset):
|
||||||
if os.path.exists(cache_path):
|
if os.path.exists(cache_path):
|
||||||
data = np.load(cache_path)
|
data = np.load(cache_path)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
data, depth = self.transform(depth_pth, depth_minmax_pth)
|
data, depth = self.transform(depth_pth, depth_minmax_pth)
|
||||||
assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
|
assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
|
||||||
data = data[np.random.choice(data.shape[0], 600, replace=False)]
|
data = data[np.random.choice(data.shape[0], 600, replace=False)]
|
||||||
np.save(cache_path, data)
|
np.save(cache_path, data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,32 @@
|
||||||
from torch import nn
|
|
||||||
from torch.autograd import Function
|
|
||||||
import torch
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Function
|
||||||
|
|
||||||
chamfer_found = importlib.find_loader("chamfer_2D") is not None
|
chamfer_found = importlib.find_loader("chamfer_2D") is not None
|
||||||
if not chamfer_found:
|
if not chamfer_found:
|
||||||
## Cool trick from https://github.com/chrdiller
|
## Cool trick from https://github.com/chrdiller
|
||||||
print("Jitting Chamfer 2D")
|
print("Jitting Chamfer 2D")
|
||||||
|
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
chamfer_2D = load(name="chamfer_2D",
|
|
||||||
sources=[
|
chamfer_2D = load(
|
||||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
name="chamfer_2D",
|
||||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
|
sources=[
|
||||||
])
|
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||||
|
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer2D.cu"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
print("Loaded JIT 2D CUDA chamfer distance")
|
print("Loaded JIT 2D CUDA chamfer distance")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import chamfer_2D
|
import chamfer_2D
|
||||||
|
|
||||||
print("Loaded compiled 2D CUDA chamfer distance")
|
print("Loaded compiled 2D CUDA chamfer distance")
|
||||||
|
|
||||||
|
|
||||||
# Chamfer's distance module @thibaultgroueix
|
# Chamfer's distance module @thibaultgroueix
|
||||||
# GPU tensors only
|
# GPU tensors only
|
||||||
class chamfer_2DFunction(Function):
|
class chamfer_2DFunction(Function):
|
||||||
|
@ -57,9 +64,7 @@ class chamfer_2DFunction(Function):
|
||||||
|
|
||||||
gradxyz1 = gradxyz1.to(device)
|
gradxyz1 = gradxyz1.to(device)
|
||||||
gradxyz2 = gradxyz2.to(device)
|
gradxyz2 = gradxyz2.to(device)
|
||||||
chamfer_2D.backward(
|
chamfer_2D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
|
||||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
|
||||||
)
|
|
||||||
return gradxyz1, gradxyz2
|
return gradxyz1, gradxyz2
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,15 +2,16 @@ from setuptools import setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='chamfer_2D',
|
name="chamfer_2D",
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CUDAExtension('chamfer_2D', [
|
CUDAExtension(
|
||||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
"chamfer_2D",
|
||||||
"/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
|
[
|
||||||
]),
|
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||||
|
"/".join(__file__.split("/")[:-1] + ["chamfer2D.cu"]),
|
||||||
|
],
|
||||||
|
),
|
||||||
],
|
],
|
||||||
|
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
cmdclass={"build_ext": BuildExtension},
|
||||||
cmdclass={
|
)
|
||||||
'build_ext': BuildExtension
|
|
||||||
})
|
|
||||||
|
|
|
@ -1,25 +1,30 @@
|
||||||
from torch import nn
|
|
||||||
from torch.autograd import Function
|
|
||||||
import torch
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Function
|
||||||
|
|
||||||
chamfer_found = importlib.find_loader("chamfer_3D") is not None
|
chamfer_found = importlib.find_loader("chamfer_3D") is not None
|
||||||
if not chamfer_found:
|
if not chamfer_found:
|
||||||
## Cool trick from https://github.com/chrdiller
|
## Cool trick from https://github.com/chrdiller
|
||||||
print("Jitting Chamfer 3D")
|
print("Jitting Chamfer 3D")
|
||||||
|
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
chamfer_3D = load(name="chamfer_3D",
|
|
||||||
sources=[
|
|
||||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
|
||||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
|
|
||||||
],
|
|
||||||
|
|
||||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],)
|
chamfer_3D = load(
|
||||||
|
name="chamfer_3D",
|
||||||
|
sources=[
|
||||||
|
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||||
|
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer3D.cu"]),
|
||||||
|
],
|
||||||
|
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||||
|
)
|
||||||
print("Loaded JIT 3D CUDA chamfer distance")
|
print("Loaded JIT 3D CUDA chamfer distance")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import chamfer_3D
|
import chamfer_3D
|
||||||
|
|
||||||
print("Loaded compiled 3D CUDA chamfer distance")
|
print("Loaded compiled 3D CUDA chamfer distance")
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,9 +65,7 @@ class chamfer_3DFunction(Function):
|
||||||
|
|
||||||
gradxyz1 = gradxyz1.to(device)
|
gradxyz1 = gradxyz1.to(device)
|
||||||
gradxyz2 = gradxyz2.to(device)
|
gradxyz2 = gradxyz2.to(device)
|
||||||
chamfer_3D.backward(
|
chamfer_3D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
|
||||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
|
||||||
)
|
|
||||||
return gradxyz1, gradxyz2
|
return gradxyz1, gradxyz2
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,4 +77,3 @@ class chamfer_3DDist(nn.Module):
|
||||||
input1 = input1.contiguous()
|
input1 = input1.contiguous()
|
||||||
input2 = input2.contiguous()
|
input2 = input2.contiguous()
|
||||||
return chamfer_3DFunction.apply(input1, input2)
|
return chamfer_3DFunction.apply(input1, input2)
|
||||||
|
|
||||||
|
|
|
@ -2,15 +2,16 @@ from setuptools import setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='chamfer_3D',
|
name="chamfer_3D",
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CUDAExtension('chamfer_3D', [
|
CUDAExtension(
|
||||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
"chamfer_3D",
|
||||||
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
|
[
|
||||||
]),
|
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||||
|
"/".join(__file__.split("/")[:-1] + ["chamfer3D.cu"]),
|
||||||
|
],
|
||||||
|
),
|
||||||
],
|
],
|
||||||
|
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
cmdclass={"build_ext": BuildExtension},
|
||||||
cmdclass={
|
)
|
||||||
'build_ext': BuildExtension
|
|
||||||
})
|
|
||||||
|
|
|
@ -1,24 +1,29 @@
|
||||||
from torch import nn
|
|
||||||
from torch.autograd import Function
|
|
||||||
import torch
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Function
|
||||||
|
|
||||||
chamfer_found = importlib.find_loader("chamfer_5D") is not None
|
chamfer_found = importlib.find_loader("chamfer_5D") is not None
|
||||||
if not chamfer_found:
|
if not chamfer_found:
|
||||||
## Cool trick from https://github.com/chrdiller
|
## Cool trick from https://github.com/chrdiller
|
||||||
print("Jitting Chamfer 5D")
|
print("Jitting Chamfer 5D")
|
||||||
|
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
chamfer_5D = load(name="chamfer_5D",
|
|
||||||
sources=[
|
chamfer_5D = load(
|
||||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
name="chamfer_5D",
|
||||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
|
sources=[
|
||||||
])
|
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||||
|
"/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer5D.cu"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
print("Loaded JIT 5D CUDA chamfer distance")
|
print("Loaded JIT 5D CUDA chamfer distance")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import chamfer_5D
|
import chamfer_5D
|
||||||
|
|
||||||
print("Loaded compiled 5D CUDA chamfer distance")
|
print("Loaded compiled 5D CUDA chamfer distance")
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,9 +64,7 @@ class chamfer_5DFunction(Function):
|
||||||
|
|
||||||
gradxyz1 = gradxyz1.to(device)
|
gradxyz1 = gradxyz1.to(device)
|
||||||
gradxyz2 = gradxyz2.to(device)
|
gradxyz2 = gradxyz2.to(device)
|
||||||
chamfer_5D.backward(
|
chamfer_5D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
|
||||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
|
||||||
)
|
|
||||||
return gradxyz1, gradxyz2
|
return gradxyz1, gradxyz2
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,15 +2,16 @@ from setuptools import setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='chamfer_5D',
|
name="chamfer_5D",
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CUDAExtension('chamfer_5D', [
|
CUDAExtension(
|
||||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
"chamfer_5D",
|
||||||
"/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
|
[
|
||||||
]),
|
"/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
|
||||||
|
"/".join(__file__.split("/")[:-1] + ["chamfer5D.cu"]),
|
||||||
|
],
|
||||||
|
),
|
||||||
],
|
],
|
||||||
|
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
cmdclass={"build_ext": BuildExtension},
|
||||||
cmdclass={
|
)
|
||||||
'build_ext': BuildExtension
|
|
||||||
})
|
|
||||||
|
|
|
@ -33,8 +33,7 @@ def distChamfer(a, b):
|
||||||
xx = torch.pow(x, 2).sum(2)
|
xx = torch.pow(x, 2).sum(2)
|
||||||
yy = torch.pow(y, 2).sum(2)
|
yy = torch.pow(y, 2).sum(2)
|
||||||
zz = torch.bmm(x, y.transpose(2, 1))
|
zz = torch.bmm(x, y.transpose(2, 1))
|
||||||
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
|
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
|
||||||
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
|
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
|
||||||
P = rx.transpose(2, 1) + ry - 2 * zz
|
P = rx.transpose(2, 1) + ry - 2 * zz
|
||||||
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
|
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def fscore(dist1, dist2, threshold=0.001):
|
def fscore(dist1, dist2, threshold=0.001):
|
||||||
"""
|
"""
|
||||||
Calculates the F-score between two point clouds with the corresponding threshold value.
|
Calculates the F-score between two point clouds with the corresponding threshold value.
|
||||||
|
@ -14,4 +15,3 @@ def fscore(dist1, dist2, threshold=0.001):
|
||||||
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
|
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
|
||||||
fscore[torch.isnan(fscore)] = 0
|
fscore[torch.isnan(fscore)] = 0
|
||||||
return fscore, precision_1, precision_2
|
return fscore, precision_1, precision_2
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,23 @@
|
||||||
import torch, time
|
import time
|
||||||
|
|
||||||
import chamfer2D.dist_chamfer_2D
|
import chamfer2D.dist_chamfer_2D
|
||||||
import chamfer3D.dist_chamfer_3D
|
import chamfer3D.dist_chamfer_3D
|
||||||
import chamfer5D.dist_chamfer_5D
|
import chamfer5D.dist_chamfer_5D
|
||||||
import chamfer_python
|
import chamfer_python
|
||||||
|
import torch
|
||||||
|
|
||||||
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
|
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
|
||||||
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
|
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
|
||||||
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
|
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
|
||||||
|
|
||||||
from torch.autograd import Variable
|
|
||||||
from fscore import fscore
|
from fscore import fscore
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
|
||||||
def test_chamfer(distChamfer, dim):
|
def test_chamfer(distChamfer, dim):
|
||||||
points1 = torch.rand(4, 100, dim).cuda()
|
points1 = torch.rand(4, 100, dim).cuda()
|
||||||
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
|
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
|
||||||
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
|
dist1, dist2, idx1, idx2 = distChamfer(points1, points2)
|
||||||
|
|
||||||
loss = torch.sum(dist1)
|
loss = torch.sum(dist1)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -29,9 +32,9 @@ def test_chamfer(distChamfer, dim):
|
||||||
xd1 = idx1 - myidx1
|
xd1 = idx1 - myidx1
|
||||||
xd2 = idx2 - myidx2
|
xd2 = idx2 - myidx2
|
||||||
assert (
|
assert (
|
||||||
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
|
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
|
||||||
), "chamfer cuda and chamfer normal are not giving the same results"
|
), "chamfer cuda and chamfer normal are not giving the same results"
|
||||||
print(f"fscore :", fscore(dist1, dist2))
|
print("fscore :", fscore(dist1, dist2))
|
||||||
print("Unit test passed")
|
print("Unit test passed")
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,7 +52,6 @@ def timings(distChamfer, dim):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
||||||
|
|
||||||
|
|
||||||
print("Timings : Start Pythonic version")
|
print("Timings : Start Pythonic version")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for i in range(num_it):
|
for i in range(num_it):
|
||||||
|
@ -61,9 +63,8 @@ def timings(distChamfer, dim):
|
||||||
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
||||||
|
|
||||||
|
|
||||||
|
dims = [2, 3, 5]
|
||||||
dims = [2,3,5]
|
for i, cham in enumerate([cham2D, cham3D, cham5D]):
|
||||||
for i,cham in enumerate([cham2D, cham3D, cham5D]):
|
|
||||||
print(f"testing Chamfer {dims[i]}D")
|
print(f"testing Chamfer {dims[i]}D")
|
||||||
test_chamfer(cham, dims[i])
|
test_chamfer(cham, dims[i])
|
||||||
timings(cham, dims[i])
|
timings(cham, dims[i])
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import torch
|
|
||||||
import emd_cuda
|
import emd_cuda
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||||
|
@ -44,4 +44,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
|
||||||
cost = cost / xyz1.shape[1]
|
cost = cost / xyz1.shape[1]
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
|
@ -9,19 +9,17 @@ Notes:
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='emd_ext',
|
name="emd_ext",
|
||||||
ext_modules=[
|
ext_modules=[
|
||||||
CUDAExtension(
|
CUDAExtension(
|
||||||
name='emd_cuda',
|
name="emd_cuda",
|
||||||
sources=[
|
sources=[
|
||||||
'cuda/emd.cpp',
|
"cuda/emd.cpp",
|
||||||
'cuda/emd_kernel.cu',
|
"cuda/emd_kernel.cu",
|
||||||
],
|
],
|
||||||
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
cmdclass={
|
cmdclass={"build_ext": BuildExtension},
|
||||||
'build_ext': BuildExtension
|
)
|
||||||
})
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import torch
|
||||||
from emd import earth_mover_distance
|
from emd import earth_mover_distance
|
||||||
|
|
||||||
# gt
|
# gt
|
||||||
|
@ -13,10 +12,12 @@ print(p2)
|
||||||
p1.requires_grad = True
|
p1.requires_grad = True
|
||||||
p2.requires_grad = True
|
p2.requires_grad = True
|
||||||
|
|
||||||
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
gt_dist = (
|
||||||
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
(((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
|
||||||
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
+ (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
|
||||||
print('gt_dist: ', gt_dist)
|
+ (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
|
||||||
|
)
|
||||||
|
print("gt_dist: ", gt_dist)
|
||||||
|
|
||||||
gt_dist.backward()
|
gt_dist.backward()
|
||||||
print(p1.grad)
|
print(p1.grad)
|
||||||
|
@ -41,4 +42,3 @@ print(loss)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
print(p1.grad)
|
print(p1.grad)
|
||||||
print(p2.grad)
|
print(p2.grad)
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from numpy.linalg import norm
|
||||||
from scipy.stats import entropy
|
from scipy.stats import entropy
|
||||||
from sklearn.neighbors import NearestNeighbors
|
from sklearn.neighbors import NearestNeighbors
|
||||||
from numpy.linalg import norm
|
|
||||||
|
|
||||||
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
|
|
||||||
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
|
|
||||||
from metrics.ChamferDistancePytorch.fscore import fscore
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
|
||||||
|
from metrics.ChamferDistancePytorch.fscore import fscore
|
||||||
|
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
|
||||||
|
|
||||||
cham3D = chamfer_3DDist()
|
cham3D = chamfer_3DDist()
|
||||||
|
|
||||||
|
|
||||||
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
|
||||||
def distChamfer(a, b):
|
def distChamfer(a, b):
|
||||||
x, y = a, b
|
x, y = a, b
|
||||||
|
@ -22,11 +24,11 @@ def distChamfer(a, b):
|
||||||
diag_ind = torch.arange(0, num_points).to(a).long()
|
diag_ind = torch.arange(0, num_points).to(a).long()
|
||||||
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
|
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
|
||||||
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
|
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
|
||||||
P = (rx.transpose(2, 1) + ry - 2 * zz)
|
P = rx.transpose(2, 1) + ry - 2 * zz
|
||||||
return P.min(1)[0], P.min(2)[0]
|
return P.min(1)[0], P.min(2)[0]
|
||||||
|
|
||||||
|
|
||||||
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
|
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
|
||||||
N_sample = sample_pcs.shape[0]
|
N_sample = sample_pcs.shape[0]
|
||||||
N_ref = ref_pcs.shape[0]
|
N_ref = ref_pcs.shape[0]
|
||||||
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
|
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
|
||||||
|
@ -56,13 +58,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
|
||||||
cd = torch.cat(cd_lst)
|
cd = torch.cat(cd_lst)
|
||||||
emd = torch.cat(emd_lst)
|
emd = torch.cat(emd_lst)
|
||||||
fs_lst = torch.cat(fs_lst).mean()
|
fs_lst = torch.cat(fs_lst).mean()
|
||||||
results = {
|
results = {"MMD-CD": cd, "MMD-EMD": emd, "fscore": fs_lst}
|
||||||
'MMD-CD': cd,
|
|
||||||
'MMD-EMD': emd,
|
|
||||||
'fscore': fs_lst
|
|
||||||
}
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
|
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
|
||||||
N_sample = sample_pcs.shape[0]
|
N_sample = sample_pcs.shape[0]
|
||||||
N_ref = ref_pcs.shape[0]
|
N_ref = ref_pcs.shape[0]
|
||||||
|
@ -107,7 +106,7 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
|
||||||
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
|
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
|
||||||
if sqrt:
|
if sqrt:
|
||||||
M = M.abs().sqrt()
|
M = M.abs().sqrt()
|
||||||
INFINITY = float('inf')
|
INFINITY = float("inf")
|
||||||
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
|
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
|
||||||
|
|
||||||
count = torch.zeros(n0 + n1).to(Mxx)
|
count = torch.zeros(n0 + n1).to(Mxx)
|
||||||
|
@ -116,19 +115,21 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
|
||||||
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
|
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
|
||||||
|
|
||||||
s = {
|
s = {
|
||||||
'tp': (pred * label).sum(),
|
"tp": (pred * label).sum(),
|
||||||
'fp': (pred * (1 - label)).sum(),
|
"fp": (pred * (1 - label)).sum(),
|
||||||
'fn': ((1 - pred) * label).sum(),
|
"fn": ((1 - pred) * label).sum(),
|
||||||
'tn': ((1 - pred) * (1 - label)).sum(),
|
"tn": ((1 - pred) * (1 - label)).sum(),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.update({
|
s.update(
|
||||||
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
|
{
|
||||||
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
|
"precision": s["tp"] / (s["tp"] + s["fp"] + 1e-10),
|
||||||
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
|
"recall": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
|
||||||
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
|
"acc_t": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
|
||||||
'acc': torch.eq(label, pred).float().mean(),
|
"acc_f": s["tn"] / (s["tn"] + s["fp"] + 1e-10),
|
||||||
})
|
"acc": torch.eq(label, pred).float().mean(),
|
||||||
|
}
|
||||||
|
)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,9 +142,9 @@ def lgan_mmd_cov(all_dist):
|
||||||
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
|
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
|
||||||
cov = torch.tensor(cov).to(all_dist)
|
cov = torch.tensor(cov).to(all_dist)
|
||||||
return {
|
return {
|
||||||
'lgan_mmd': mmd,
|
"lgan_mmd": mmd,
|
||||||
'lgan_cov': cov,
|
"lgan_cov": cov,
|
||||||
'lgan_mmd_smp': mmd_smp,
|
"lgan_mmd_smp": mmd_smp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,27 +154,19 @@ def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
|
||||||
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
|
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
|
||||||
|
|
||||||
res_cd = lgan_mmd_cov(M_rs_cd.t())
|
res_cd = lgan_mmd_cov(M_rs_cd.t())
|
||||||
results.update({
|
results.update({"%s-CD" % k: v for k, v in res_cd.items()})
|
||||||
"%s-CD" % k: v for k, v in res_cd.items()
|
|
||||||
})
|
|
||||||
|
|
||||||
res_emd = lgan_mmd_cov(M_rs_emd.t())
|
res_emd = lgan_mmd_cov(M_rs_emd.t())
|
||||||
results.update({
|
results.update({"%s-EMD" % k: v for k, v in res_emd.items()})
|
||||||
"%s-EMD" % k: v for k, v in res_emd.items()
|
|
||||||
})
|
|
||||||
|
|
||||||
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
|
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
|
||||||
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
|
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
|
||||||
|
|
||||||
# 1-NN results
|
# 1-NN results
|
||||||
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
|
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
|
||||||
results.update({
|
results.update({"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if "acc" in k})
|
||||||
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
|
|
||||||
})
|
|
||||||
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
|
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
|
||||||
results.update({
|
results.update({"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if "acc" in k})
|
||||||
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -227,11 +220,11 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
|
||||||
bound = 0.5 + epsilon
|
bound = 0.5 + epsilon
|
||||||
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
||||||
if verbose:
|
if verbose:
|
||||||
warnings.warn('Point-clouds are not in unit cube.')
|
warnings.warn("Point-clouds are not in unit cube.")
|
||||||
|
|
||||||
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
|
if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
|
||||||
if verbose:
|
if verbose:
|
||||||
warnings.warn('Point-clouds are not in unit sphere.')
|
warnings.warn("Point-clouds are not in unit sphere.")
|
||||||
|
|
||||||
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
||||||
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
||||||
|
@ -260,9 +253,9 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
|
||||||
|
|
||||||
def jensen_shannon_divergence(P, Q):
|
def jensen_shannon_divergence(P, Q):
|
||||||
if np.any(P < 0) or np.any(Q < 0):
|
if np.any(P < 0) or np.any(Q < 0):
|
||||||
raise ValueError('Negative values.')
|
raise ValueError("Negative values.")
|
||||||
if len(P) != len(Q):
|
if len(P) != len(Q):
|
||||||
raise ValueError('Non equal size.')
|
raise ValueError("Non equal size.")
|
||||||
|
|
||||||
P_ = P / np.sum(P) # Ensure probabilities.
|
P_ = P / np.sum(P) # Ensure probabilities.
|
||||||
Q_ = Q / np.sum(Q)
|
Q_ = Q / np.sum(Q)
|
||||||
|
@ -275,7 +268,7 @@ def jensen_shannon_divergence(P, Q):
|
||||||
res2 = _jsdiv(P_, Q_)
|
res2 = _jsdiv(P_, Q_)
|
||||||
|
|
||||||
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
||||||
warnings.warn('Numerical values of two JSD methods don\'t agree.')
|
warnings.warn("Numerical values of two JSD methods don't agree.")
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -312,11 +305,9 @@ if __name__ == "__main__":
|
||||||
r_dist = min_r.mean().cpu().detach().item()
|
r_dist = min_r.mean().cpu().detach().item()
|
||||||
print(l_dist, r_dist)
|
print(l_dist, r_dist)
|
||||||
|
|
||||||
|
|
||||||
emd_batch = EMD(x.cuda(), y.cuda(), False)
|
emd_batch = EMD(x.cuda(), y.cuda(), False)
|
||||||
print(emd_batch.shape)
|
print(emd_batch.shape)
|
||||||
print(emd_batch.mean().detach().item())
|
print(emd_batch.mean().detach().item())
|
||||||
|
|
||||||
jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy())
|
jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy())
|
||||||
print(jsd)
|
print(jsd)
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
|
||||||
|
|
||||||
|
|
||||||
def _linear_gn_relu(in_channels, out_channels):
|
def _linear_gn_relu(in_channels, out_channels):
|
||||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
|
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
|
||||||
|
|
||||||
|
|
||||||
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
||||||
|
@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
|
||||||
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
||||||
|
|
||||||
|
|
||||||
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
|
def create_pointnet_components(
|
||||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
blocks,
|
||||||
|
in_channels,
|
||||||
|
embed_dim,
|
||||||
|
with_se=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||||
|
|
||||||
layers, concat_channels = [], 0
|
layers, concat_channels = [], 0
|
||||||
|
@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
|
||||||
if voxel_resolution is None:
|
if voxel_resolution is None:
|
||||||
block = SharedMLP
|
block = SharedMLP
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
block = functools.partial(
|
||||||
with_se=with_se, normalize=normalize, eps=eps)
|
PVConv,
|
||||||
|
kernel_size=3,
|
||||||
|
resolution=int(vr * voxel_resolution),
|
||||||
|
attention=attention,
|
||||||
|
with_se=with_se,
|
||||||
|
normalize=normalize,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
if c == 0:
|
if c == 0:
|
||||||
layers.append(block(in_channels, out_channels))
|
layers.append(block(in_channels, out_channels))
|
||||||
else:
|
else:
|
||||||
layers.append(block(in_channels+embed_dim, out_channels))
|
layers.append(block(in_channels + embed_dim, out_channels))
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
concat_channels += out_channels
|
concat_channels += out_channels
|
||||||
c += 1
|
c += 1
|
||||||
return layers, in_channels, concat_channels
|
return layers, in_channels, concat_channels
|
||||||
|
|
||||||
|
|
||||||
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
|
def create_pointnet2_sa_components(
|
||||||
dropout=0.1, with_se=False, normalize=True, eps=0,
|
sa_blocks,
|
||||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
extra_feature_channels,
|
||||||
|
embed_dim=64,
|
||||||
|
use_att=False,
|
||||||
|
dropout=0.1,
|
||||||
|
with_se=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||||
in_channels = extra_feature_channels + 3
|
in_channels = extra_feature_channels + 3
|
||||||
|
|
||||||
|
@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
||||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||||
out_channels = int(r * out_channels)
|
out_channels = int(r * out_channels)
|
||||||
for p in range(num_blocks):
|
for p in range(num_blocks):
|
||||||
attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0
|
attention = (c + 1) % 2 == 0 and c > 0 and use_att and p == 0
|
||||||
if voxel_resolution is None:
|
if voxel_resolution is None:
|
||||||
block = SharedMLP
|
block = SharedMLP
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
block = functools.partial(
|
||||||
dropout=dropout,
|
PVConv,
|
||||||
with_se=with_se and not attention, with_se_relu=True,
|
kernel_size=3,
|
||||||
normalize=normalize, eps=eps)
|
resolution=int(vr * voxel_resolution),
|
||||||
|
attention=attention,
|
||||||
|
dropout=dropout,
|
||||||
|
with_se=with_se and not attention,
|
||||||
|
with_se_relu=True,
|
||||||
|
normalize=normalize,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
if c == 0:
|
if c == 0:
|
||||||
sa_blocks.append(block(in_channels, out_channels))
|
sa_blocks.append(block(in_channels, out_channels))
|
||||||
elif k ==0:
|
elif k == 0:
|
||||||
sa_blocks.append(block(in_channels+embed_dim, out_channels))
|
sa_blocks.append(block(in_channels + embed_dim, out_channels))
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
k += 1
|
k += 1
|
||||||
extra_feature_channels = in_channels
|
extra_feature_channels = in_channels
|
||||||
|
@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
||||||
if num_centers is None:
|
if num_centers is None:
|
||||||
block = PointNetAModule
|
block = PointNetAModule
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
|
block = functools.partial(
|
||||||
num_neighbors=num_neighbors)
|
PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
|
||||||
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
|
)
|
||||||
include_coordinates=True))
|
sa_blocks.append(
|
||||||
|
block(
|
||||||
|
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
|
||||||
|
out_channels=out_channels,
|
||||||
|
include_coordinates=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
c += 1
|
c += 1
|
||||||
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
||||||
if len(sa_blocks) == 1:
|
if len(sa_blocks) == 1:
|
||||||
|
@ -127,10 +165,20 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
||||||
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
||||||
|
|
||||||
|
|
||||||
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False,
|
def create_pointnet2_fp_modules(
|
||||||
dropout=0.1,
|
fp_blocks,
|
||||||
with_se=False, normalize=True, eps=0,
|
in_channels,
|
||||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
sa_in_channels,
|
||||||
|
sv_points,
|
||||||
|
embed_dim=64,
|
||||||
|
use_att=False,
|
||||||
|
dropout=0.1,
|
||||||
|
with_se=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||||
|
|
||||||
fp_layers = []
|
fp_layers = []
|
||||||
|
@ -139,7 +187,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
|
||||||
fp_blocks = []
|
fp_blocks = []
|
||||||
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
||||||
fp_blocks.append(
|
fp_blocks.append(
|
||||||
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
|
PointNetFPModule(
|
||||||
|
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
|
||||||
|
)
|
||||||
)
|
)
|
||||||
in_channels = out_channels[-1]
|
in_channels = out_channels[-1]
|
||||||
|
|
||||||
|
@ -151,9 +201,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
|
||||||
if voxel_resolution is None:
|
if voxel_resolution is None:
|
||||||
block = SharedMLP
|
block = SharedMLP
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
block = functools.partial(
|
||||||
dropout=dropout,
|
PVConv,
|
||||||
with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps)
|
kernel_size=3,
|
||||||
|
resolution=int(vr * voxel_resolution),
|
||||||
|
attention=attention,
|
||||||
|
dropout=dropout,
|
||||||
|
with_se=with_se and not attention,
|
||||||
|
with_se_relu=True,
|
||||||
|
normalize=normalize,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
fp_blocks.append(block(in_channels, out_channels))
|
fp_blocks.append(block(in_channels, out_channels))
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
|
@ -168,9 +226,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
|
||||||
|
|
||||||
|
|
||||||
class PVCNN2Base(nn.Module):
|
class PVCNN2Base(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1,
|
self,
|
||||||
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
|
num_classes,
|
||||||
|
sv_points,
|
||||||
|
embed_dim,
|
||||||
|
use_att,
|
||||||
|
dropout=0.1,
|
||||||
|
extra_feature_channels=3,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert extra_feature_channels >= 0
|
assert extra_feature_channels >= 0
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
@ -178,9 +244,14 @@ class PVCNN2Base(nn.Module):
|
||||||
self.in_channels = extra_feature_channels + 3
|
self.in_channels = extra_feature_channels + 3
|
||||||
|
|
||||||
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
|
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
|
||||||
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
|
sa_blocks=self.sa_blocks,
|
||||||
use_att=use_att, dropout=dropout,
|
extra_feature_channels=extra_feature_channels,
|
||||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
with_se=True,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
use_att=use_att,
|
||||||
|
dropout=dropout,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||||
)
|
)
|
||||||
self.sa_layers = nn.ModuleList(sa_layers)
|
self.sa_layers = nn.ModuleList(sa_layers)
|
||||||
|
|
||||||
|
@ -189,16 +260,26 @@ class PVCNN2Base(nn.Module):
|
||||||
# only use extra features in the last fp module
|
# only use extra features in the last fp module
|
||||||
sa_in_channels[0] = extra_feature_channels
|
sa_in_channels[0] = extra_feature_channels
|
||||||
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
||||||
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points,
|
fp_blocks=self.fp_blocks,
|
||||||
with_se=True, embed_dim=embed_dim,
|
in_channels=channels_sa_features,
|
||||||
use_att=use_att, dropout=dropout,
|
sa_in_channels=sa_in_channels,
|
||||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
sv_points=sv_points,
|
||||||
|
with_se=True,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
use_att=use_att,
|
||||||
|
dropout=dropout,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||||
)
|
)
|
||||||
self.fp_layers = nn.ModuleList(fp_layers)
|
self.fp_layers = nn.ModuleList(fp_layers)
|
||||||
|
|
||||||
|
layers, _ = create_mlp_components(
|
||||||
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes],
|
in_channels=channels_fp_features,
|
||||||
classifier=True, dim=2, width_multiplier=width_multiplier)
|
out_channels=[128, 0.5, num_classes],
|
||||||
|
classifier=True,
|
||||||
|
dim=2,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
)
|
||||||
self.classifier = nn.Sequential(*layers)
|
self.classifier = nn.Sequential(*layers)
|
||||||
|
|
||||||
self.embedf = nn.Sequential(
|
self.embedf = nn.Sequential(
|
||||||
|
@ -223,31 +304,30 @@ class PVCNN2Base(nn.Module):
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
def forward(self, inputs, t):
|
def forward(self, inputs, t):
|
||||||
|
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
|
||||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
|
|
||||||
|
|
||||||
# inputs : [B, in_channels + S, N]
|
# inputs : [B, in_channels + S, N]
|
||||||
coords, features = inputs[:, :3, :].contiguous(), inputs
|
coords, features = inputs[:, :3, :].contiguous(), inputs
|
||||||
coords_list, in_features_list = [], []
|
coords_list, in_features_list = [], []
|
||||||
for i, sa_blocks in enumerate(self.sa_layers):
|
for i, sa_blocks in enumerate(self.sa_layers):
|
||||||
in_features_list.append(features)
|
in_features_list.append(features)
|
||||||
coords_list.append(coords)
|
coords_list.append(coords)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
features, coords, temb = sa_blocks ((features, coords, temb))
|
features, coords, temb = sa_blocks((features, coords, temb))
|
||||||
else:
|
else:
|
||||||
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
|
features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
|
||||||
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
||||||
if self.global_att is not None:
|
if self.global_att is not None:
|
||||||
features = self.global_att(features)
|
features = self.global_att(features)
|
||||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||||
jump_coords = coords_list[-1 - fp_idx]
|
jump_coords = coords_list[-1 - fp_idx]
|
||||||
fump_feats = in_features_list[-1-fp_idx]
|
fump_feats = in_features_list[-1 - fp_idx]
|
||||||
# if fp_idx == len(self.fp_layers) - 1:
|
# if fp_idx == len(self.fp_layers) - 1:
|
||||||
# jump_coords = jump_coords[:,:,self.sv_points:]
|
# jump_coords = jump_coords[:,:,self.sv_points:]
|
||||||
# fump_feats = fump_feats[:,:,self.sv_points:]
|
# fump_feats = fump_feats[:,:,self.sv_points:]
|
||||||
|
|
||||||
features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features,temb],dim=1), fump_feats, temb))
|
features, coords, temb = fp_blocks(
|
||||||
|
(jump_coords, coords, torch.cat([features, temb], dim=1), fump_feats, temb)
|
||||||
|
)
|
||||||
|
|
||||||
return self.classifier(features)
|
return self.classifier(features)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
|
||||||
|
|
||||||
|
|
||||||
def _linear_gn_relu(in_channels, out_channels):
|
def _linear_gn_relu(in_channels, out_channels):
|
||||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
|
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
|
||||||
|
|
||||||
|
|
||||||
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
||||||
|
@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
|
||||||
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
||||||
|
|
||||||
|
|
||||||
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
|
def create_pointnet_components(
|
||||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
blocks,
|
||||||
|
in_channels,
|
||||||
|
embed_dim,
|
||||||
|
with_se=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||||
|
|
||||||
layers, concat_channels = [], 0
|
layers, concat_channels = [], 0
|
||||||
|
@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
|
||||||
if voxel_resolution is None:
|
if voxel_resolution is None:
|
||||||
block = SharedMLP
|
block = SharedMLP
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
block = functools.partial(
|
||||||
with_se=with_se, normalize=normalize, eps=eps)
|
PVConv,
|
||||||
|
kernel_size=3,
|
||||||
|
resolution=int(vr * voxel_resolution),
|
||||||
|
attention=attention,
|
||||||
|
with_se=with_se,
|
||||||
|
normalize=normalize,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
if c == 0:
|
if c == 0:
|
||||||
layers.append(block(in_channels, out_channels))
|
layers.append(block(in_channels, out_channels))
|
||||||
else:
|
else:
|
||||||
layers.append(block(in_channels+embed_dim, out_channels))
|
layers.append(block(in_channels + embed_dim, out_channels))
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
concat_channels += out_channels
|
concat_channels += out_channels
|
||||||
c += 1
|
c += 1
|
||||||
return layers, in_channels, concat_channels
|
return layers, in_channels, concat_channels
|
||||||
|
|
||||||
|
|
||||||
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
|
def create_pointnet2_sa_components(
|
||||||
dropout=0.1, with_se=False, normalize=True, eps=0,
|
sa_blocks,
|
||||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
extra_feature_channels,
|
||||||
|
embed_dim=64,
|
||||||
|
use_att=False,
|
||||||
|
dropout=0.1,
|
||||||
|
with_se=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||||
in_channels = extra_feature_channels + 3
|
in_channels = extra_feature_channels + 3
|
||||||
|
|
||||||
|
@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
||||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||||
out_channels = int(r * out_channels)
|
out_channels = int(r * out_channels)
|
||||||
for p in range(num_blocks):
|
for p in range(num_blocks):
|
||||||
attention = (c+1) % 2 == 0 and use_att and p == 0
|
attention = (c + 1) % 2 == 0 and use_att and p == 0
|
||||||
if voxel_resolution is None:
|
if voxel_resolution is None:
|
||||||
block = SharedMLP
|
block = SharedMLP
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
block = functools.partial(
|
||||||
dropout=dropout,
|
PVConv,
|
||||||
with_se=with_se, with_se_relu=True,
|
kernel_size=3,
|
||||||
normalize=normalize, eps=eps)
|
resolution=int(vr * voxel_resolution),
|
||||||
|
attention=attention,
|
||||||
|
dropout=dropout,
|
||||||
|
with_se=with_se,
|
||||||
|
with_se_relu=True,
|
||||||
|
normalize=normalize,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
if c == 0:
|
if c == 0:
|
||||||
sa_blocks.append(block(in_channels, out_channels))
|
sa_blocks.append(block(in_channels, out_channels))
|
||||||
elif k ==0:
|
elif k == 0:
|
||||||
sa_blocks.append(block(in_channels+embed_dim, out_channels))
|
sa_blocks.append(block(in_channels + embed_dim, out_channels))
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
k += 1
|
k += 1
|
||||||
extra_feature_channels = in_channels
|
extra_feature_channels = in_channels
|
||||||
|
@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
||||||
if num_centers is None:
|
if num_centers is None:
|
||||||
block = PointNetAModule
|
block = PointNetAModule
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
|
block = functools.partial(
|
||||||
num_neighbors=num_neighbors)
|
PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
|
||||||
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
|
)
|
||||||
include_coordinates=True))
|
sa_blocks.append(
|
||||||
|
block(
|
||||||
|
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
|
||||||
|
out_channels=out_channels,
|
||||||
|
include_coordinates=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
c += 1
|
c += 1
|
||||||
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
||||||
if len(sa_blocks) == 1:
|
if len(sa_blocks) == 1:
|
||||||
|
@ -127,10 +165,19 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
|
||||||
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
||||||
|
|
||||||
|
|
||||||
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
|
def create_pointnet2_fp_modules(
|
||||||
dropout=0.1,
|
fp_blocks,
|
||||||
with_se=False, normalize=True, eps=0,
|
in_channels,
|
||||||
width_multiplier=1, voxel_resolution_multiplier=1):
|
sa_in_channels,
|
||||||
|
embed_dim=64,
|
||||||
|
use_att=False,
|
||||||
|
dropout=0.1,
|
||||||
|
with_se=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||||
|
|
||||||
fp_layers = []
|
fp_layers = []
|
||||||
|
@ -139,7 +186,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
|
||||||
fp_blocks = []
|
fp_blocks = []
|
||||||
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
||||||
fp_blocks.append(
|
fp_blocks.append(
|
||||||
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
|
PointNetFPModule(
|
||||||
|
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
|
||||||
|
)
|
||||||
)
|
)
|
||||||
in_channels = out_channels[-1]
|
in_channels = out_channels[-1]
|
||||||
|
|
||||||
|
@ -147,14 +196,21 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
|
||||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||||
out_channels = int(r * out_channels)
|
out_channels = int(r * out_channels)
|
||||||
for p in range(num_blocks):
|
for p in range(num_blocks):
|
||||||
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
|
attention = (c + 1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
|
||||||
if voxel_resolution is None:
|
if voxel_resolution is None:
|
||||||
block = SharedMLP
|
block = SharedMLP
|
||||||
else:
|
else:
|
||||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
block = functools.partial(
|
||||||
dropout=dropout,
|
PVConv,
|
||||||
with_se=with_se, with_se_relu=True,
|
kernel_size=3,
|
||||||
normalize=normalize, eps=eps)
|
resolution=int(vr * voxel_resolution),
|
||||||
|
attention=attention,
|
||||||
|
dropout=dropout,
|
||||||
|
with_se=with_se,
|
||||||
|
with_se_relu=True,
|
||||||
|
normalize=normalize,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
fp_blocks.append(block(in_channels, out_channels))
|
fp_blocks.append(block(in_channels, out_channels))
|
||||||
in_channels = out_channels
|
in_channels = out_channels
|
||||||
|
@ -168,20 +224,31 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
|
||||||
return fp_layers, in_channels
|
return fp_layers, in_channels
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PVCNN2Base(nn.Module):
|
class PVCNN2Base(nn.Module):
|
||||||
|
def __init__(
|
||||||
def __init__(self, num_classes, embed_dim, use_att, dropout=0.1,
|
self,
|
||||||
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
|
num_classes,
|
||||||
|
embed_dim,
|
||||||
|
use_att,
|
||||||
|
dropout=0.1,
|
||||||
|
extra_feature_channels=3,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert extra_feature_channels >= 0
|
assert extra_feature_channels >= 0
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
self.in_channels = extra_feature_channels + 3
|
self.in_channels = extra_feature_channels + 3
|
||||||
|
|
||||||
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
|
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
|
||||||
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
|
sa_blocks=self.sa_blocks,
|
||||||
use_att=use_att, dropout=dropout,
|
extra_feature_channels=extra_feature_channels,
|
||||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
with_se=True,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
use_att=use_att,
|
||||||
|
dropout=dropout,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||||
)
|
)
|
||||||
self.sa_layers = nn.ModuleList(sa_layers)
|
self.sa_layers = nn.ModuleList(sa_layers)
|
||||||
|
|
||||||
|
@ -190,15 +257,25 @@ class PVCNN2Base(nn.Module):
|
||||||
# only use extra features in the last fp module
|
# only use extra features in the last fp module
|
||||||
sa_in_channels[0] = extra_feature_channels
|
sa_in_channels[0] = extra_feature_channels
|
||||||
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
||||||
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim,
|
fp_blocks=self.fp_blocks,
|
||||||
use_att=use_att, dropout=dropout,
|
in_channels=channels_sa_features,
|
||||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
sa_in_channels=sa_in_channels,
|
||||||
|
with_se=True,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
use_att=use_att,
|
||||||
|
dropout=dropout,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||||
)
|
)
|
||||||
self.fp_layers = nn.ModuleList(fp_layers)
|
self.fp_layers = nn.ModuleList(fp_layers)
|
||||||
|
|
||||||
|
layers, _ = create_mlp_components(
|
||||||
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes], # was 0.5
|
in_channels=channels_fp_features,
|
||||||
classifier=True, dim=2, width_multiplier=width_multiplier)
|
out_channels=[128, dropout, num_classes], # was 0.5
|
||||||
|
classifier=True,
|
||||||
|
dim=2,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
)
|
||||||
self.classifier = nn.Sequential(*layers)
|
self.classifier = nn.Sequential(*layers)
|
||||||
|
|
||||||
self.embedf = nn.Sequential(
|
self.embedf = nn.Sequential(
|
||||||
|
@ -223,25 +300,30 @@ class PVCNN2Base(nn.Module):
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
def forward(self, inputs, t):
|
def forward(self, inputs, t):
|
||||||
|
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
|
||||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
|
|
||||||
|
|
||||||
# inputs : [B, in_channels + S, N]
|
# inputs : [B, in_channels + S, N]
|
||||||
coords, features = inputs[:, :3, :].contiguous(), inputs
|
coords, features = inputs[:, :3, :].contiguous(), inputs
|
||||||
coords_list, in_features_list = [], []
|
coords_list, in_features_list = [], []
|
||||||
for i, sa_blocks in enumerate(self.sa_layers):
|
for i, sa_blocks in enumerate(self.sa_layers):
|
||||||
in_features_list.append(features)
|
in_features_list.append(features)
|
||||||
coords_list.append(coords)
|
coords_list.append(coords)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
features, coords, temb = sa_blocks ((features, coords, temb))
|
features, coords, temb = sa_blocks((features, coords, temb))
|
||||||
else:
|
else:
|
||||||
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
|
features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
|
||||||
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
||||||
if self.global_att is not None:
|
if self.global_att is not None:
|
||||||
features = self.global_att(features)
|
features = self.global_att(features)
|
||||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||||
features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb))
|
features, coords, temb = fp_blocks(
|
||||||
|
(
|
||||||
|
coords_list[-1 - fp_idx],
|
||||||
|
coords,
|
||||||
|
torch.cat([features, temb], dim=1),
|
||||||
|
in_features_list[-1 - fp_idx],
|
||||||
|
temb,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return self.classifier(features)
|
return self.classifier(features)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
from modules.ball_query import BallQuery
|
|
||||||
from modules.frustum import FrustumPointNetLoss
|
|
||||||
from modules.loss import KLLoss
|
|
||||||
from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
|
|
||||||
from modules.pvconv import PVConv, Attention, Swish, PVConvReLU
|
|
||||||
from modules.se import SE3d
|
|
||||||
from modules.shared_mlp import SharedMLP
|
|
||||||
from modules.voxelization import Voxelization
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
import modules.functional as F
|
import modules.functional as F
|
||||||
|
|
||||||
__all__ = ['BallQuery']
|
__all__ = ["BallQuery"]
|
||||||
|
|
||||||
|
|
||||||
class BallQuery(nn.Module):
|
class BallQuery(nn.Module):
|
||||||
|
@ -21,7 +21,7 @@ class BallQuery(nn.Module):
|
||||||
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
|
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
|
||||||
|
|
||||||
if points_features is None:
|
if points_features is None:
|
||||||
assert self.include_coordinates, 'No Features For Grouping'
|
assert self.include_coordinates, "No Features For Grouping"
|
||||||
neighbor_features = neighbor_coordinates
|
neighbor_features = neighbor_coordinates
|
||||||
else:
|
else:
|
||||||
neighbor_features = F.grouping(points_features, neighbor_indices)
|
neighbor_features = F.grouping(points_features, neighbor_indices)
|
||||||
|
@ -30,5 +30,6 @@ class BallQuery(nn.Module):
|
||||||
return neighbor_features, F.grouping(temb, neighbor_indices)
|
return neighbor_features, F.grouping(temb, neighbor_indices)
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return 'radius={}, num_neighbors={}{}'.format(
|
return "radius={}, num_neighbors={}{}".format(
|
||||||
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
|
self.radius, self.num_neighbors, ", include coordinates" if self.include_coordinates else ""
|
||||||
|
)
|
||||||
|
|
|
@ -5,12 +5,20 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
import modules.functional as PF
|
import modules.functional as PF
|
||||||
|
|
||||||
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
|
__all__ = ["FrustumPointNetLoss", "get_box_corners_3d"]
|
||||||
|
|
||||||
|
|
||||||
class FrustumPointNetLoss(nn.Module):
|
class FrustumPointNetLoss(nn.Module):
|
||||||
def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
|
def __init__(
|
||||||
corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
|
self,
|
||||||
|
num_heading_angle_bins,
|
||||||
|
num_size_templates,
|
||||||
|
size_templates,
|
||||||
|
box_loss_weight=1.0,
|
||||||
|
corners_loss_weight=10.0,
|
||||||
|
heading_residual_loss_weight=20.0,
|
||||||
|
size_residual_loss_weight=20.0,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.box_loss_weight = box_loss_weight
|
self.box_loss_weight = box_loss_weight
|
||||||
self.corners_loss_weight = corners_loss_weight
|
self.corners_loss_weight = corners_loss_weight
|
||||||
|
@ -19,28 +27,28 @@ class FrustumPointNetLoss(nn.Module):
|
||||||
|
|
||||||
self.num_heading_angle_bins = num_heading_angle_bins
|
self.num_heading_angle_bins = num_heading_angle_bins
|
||||||
self.num_size_templates = num_size_templates
|
self.num_size_templates = num_size_templates
|
||||||
self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
|
self.register_buffer("size_templates", size_templates.view(self.num_size_templates, 3))
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
|
"heading_angle_bin_centers", torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, inputs, targets):
|
def forward(self, inputs, targets):
|
||||||
mask_logits = inputs['mask_logits'] # (B, 2, N)
|
mask_logits = inputs["mask_logits"] # (B, 2, N)
|
||||||
center_reg = inputs['center_reg'] # (B, 3)
|
center_reg = inputs["center_reg"] # (B, 3)
|
||||||
center = inputs['center'] # (B, 3)
|
center = inputs["center"] # (B, 3)
|
||||||
heading_scores = inputs['heading_scores'] # (B, NH)
|
heading_scores = inputs["heading_scores"] # (B, NH)
|
||||||
heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
|
heading_residuals_normalized = inputs["heading_residuals_normalized"] # (B, NH)
|
||||||
heading_residuals = inputs['heading_residuals'] # (B, NH)
|
heading_residuals = inputs["heading_residuals"] # (B, NH)
|
||||||
size_scores = inputs['size_scores'] # (B, NS)
|
size_scores = inputs["size_scores"] # (B, NS)
|
||||||
size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
|
size_residuals_normalized = inputs["size_residuals_normalized"] # (B, NS, 3)
|
||||||
size_residuals = inputs['size_residuals'] # (B, NS, 3)
|
size_residuals = inputs["size_residuals"] # (B, NS, 3)
|
||||||
|
|
||||||
mask_logits_target = targets['mask_logits'] # (B, N)
|
mask_logits_target = targets["mask_logits"] # (B, N)
|
||||||
center_target = targets['center'] # (B, 3)
|
center_target = targets["center"] # (B, 3)
|
||||||
heading_bin_id_target = targets['heading_bin_id'] # (B, )
|
heading_bin_id_target = targets["heading_bin_id"] # (B, )
|
||||||
heading_residual_target = targets['heading_residual'] # (B, )
|
heading_residual_target = targets["heading_residual"] # (B, )
|
||||||
size_template_id_target = targets['size_template_id'] # (B, )
|
size_template_id_target = targets["size_template_id"] # (B, )
|
||||||
size_residual_target = targets['size_residual'] # (B, 3)
|
size_residual_target = targets["size_residual"] # (B, 3)
|
||||||
|
|
||||||
batch_size = center.size(0)
|
batch_size = center.size(0)
|
||||||
batch_id = torch.arange(batch_size, device=center.device)
|
batch_id = torch.arange(batch_size, device=center.device)
|
||||||
|
@ -65,25 +73,32 @@ class FrustumPointNetLoss(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bounding box losses
|
# Bounding box losses
|
||||||
heading = (heading_residuals[batch_id, heading_bin_id_target]
|
heading = (
|
||||||
+ self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
|
heading_residuals[batch_id, heading_bin_id_target] + self.heading_angle_bin_centers[heading_bin_id_target]
|
||||||
|
) # (B, )
|
||||||
# Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
|
# Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
|
||||||
size = (size_residuals[batch_id, size_template_id_target]
|
size = (
|
||||||
+ self.size_templates[size_template_id_target]) # (B, 3)
|
size_residuals[batch_id, size_template_id_target] + self.size_templates[size_template_id_target]
|
||||||
|
) # (B, 3)
|
||||||
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
|
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
|
||||||
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
|
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
|
||||||
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
|
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
|
||||||
corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
|
corners_target, corners_target_flip = get_box_corners_3d(
|
||||||
sizes=size_target, with_flip=True) # (B, 3, 8)
|
centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
|
||||||
corners_loss = PF.huber_loss(torch.min(
|
) # (B, 3, 8)
|
||||||
torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
|
corners_loss = PF.huber_loss(
|
||||||
), delta=1.0)
|
torch.min(torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)),
|
||||||
|
delta=1.0,
|
||||||
|
)
|
||||||
# Summing up
|
# Summing up
|
||||||
loss = mask_loss + self.box_loss_weight * (
|
loss = mask_loss + self.box_loss_weight * (
|
||||||
center_loss + center_reg_loss + heading_loss + size_loss
|
center_loss
|
||||||
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
|
+ center_reg_loss
|
||||||
+ self.size_residual_loss_weight * size_residual_normalized_loss
|
+ heading_loss
|
||||||
+ self.corners_loss_weight * corners_loss
|
+ size_loss
|
||||||
|
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
|
||||||
|
+ self.size_residual_loss_weight * size_residual_normalized_loss
|
||||||
|
+ self.corners_loss_weight * corners_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
@ -105,9 +120,9 @@ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
|
||||||
l = sizes[:, 0] # (N,)
|
l = sizes[:, 0] # (N,)
|
||||||
w = sizes[:, 1] # (N,)
|
w = sizes[:, 1] # (N,)
|
||||||
h = sizes[:, 2] # (N,)
|
h = sizes[:, 2] # (N,)
|
||||||
x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8)
|
x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
|
||||||
y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8)
|
y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
|
||||||
z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8)
|
z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
|
||||||
|
|
||||||
c = torch.cos(headings) # (N,)
|
c = torch.cos(headings) # (N,)
|
||||||
s = torch.sin(headings) # (N,)
|
s = torch.sin(headings) # (N,)
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
from modules.functional.ball_query import ball_query
|
|
||||||
from modules.functional.devoxelization import trilinear_devoxelize
|
|
||||||
from modules.functional.grouping import grouping
|
|
||||||
from modules.functional.interpolatation import nearest_neighbor_interpolate
|
|
||||||
from modules.functional.loss import kl_loss, huber_loss
|
|
||||||
from modules.functional.sampling import gather, furthest_point_sample, logits_mask
|
|
||||||
from modules.functional.voxelization import avg_voxelize
|
|
|
@ -3,24 +3,28 @@ import os
|
||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
_backend = load(name='_pvcnn_backend',
|
_backend = load(
|
||||||
extra_cflags=['-O3', '-std=c++17'],
|
name="_pvcnn_backend",
|
||||||
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
|
extra_cflags=["-O3", "-std=c++17"],
|
||||||
sources=[os.path.join(_src_path,'src', f) for f in [
|
extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
|
||||||
'ball_query/ball_query.cpp',
|
sources=[
|
||||||
'ball_query/ball_query.cu',
|
os.path.join(_src_path, "src", f)
|
||||||
'grouping/grouping.cpp',
|
for f in [
|
||||||
'grouping/grouping.cu',
|
"ball_query/ball_query.cpp",
|
||||||
'interpolate/neighbor_interpolate.cpp',
|
"ball_query/ball_query.cu",
|
||||||
'interpolate/neighbor_interpolate.cu',
|
"grouping/grouping.cpp",
|
||||||
'interpolate/trilinear_devox.cpp',
|
"grouping/grouping.cu",
|
||||||
'interpolate/trilinear_devox.cu',
|
"interpolate/neighbor_interpolate.cpp",
|
||||||
'sampling/sampling.cpp',
|
"interpolate/neighbor_interpolate.cu",
|
||||||
'sampling/sampling.cu',
|
"interpolate/trilinear_devox.cpp",
|
||||||
'voxelization/vox.cpp',
|
"interpolate/trilinear_devox.cu",
|
||||||
'voxelization/vox.cu',
|
"sampling/sampling.cpp",
|
||||||
'bindings.cpp',
|
"sampling/sampling.cu",
|
||||||
]]
|
"voxelization/vox.cpp",
|
||||||
)
|
"voxelization/vox.cu",
|
||||||
|
"bindings.cpp",
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ['_backend']
|
__all__ = ["_backend"]
|
||||||
|
|
|
@ -1,19 +1,17 @@
|
||||||
from torch.autograd import Function
|
|
||||||
|
|
||||||
from modules.functional.backend import _backend
|
from modules.functional.backend import _backend
|
||||||
|
|
||||||
__all__ = ['ball_query']
|
__all__ = ["ball_query"]
|
||||||
|
|
||||||
|
|
||||||
def ball_query(centers_coords, points_coords, radius, num_neighbors):
|
def ball_query(centers_coords, points_coords, radius, num_neighbors):
|
||||||
"""
|
"""
|
||||||
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
||||||
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
||||||
:param radius: float, radius of ball query
|
:param radius: float, radius of ball query
|
||||||
:param num_neighbors: int, maximum number of neighbors
|
:param num_neighbors: int, maximum number of neighbors
|
||||||
:return:
|
:return:
|
||||||
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
|
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
|
||||||
"""
|
"""
|
||||||
centers_coords = centers_coords.contiguous()
|
centers_coords = centers_coords.contiguous()
|
||||||
points_coords = points_coords.contiguous()
|
points_coords = points_coords.contiguous()
|
||||||
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
|
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
|
||||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
||||||
|
|
||||||
from modules.functional.backend import _backend
|
from modules.functional.backend import _backend
|
||||||
|
|
||||||
__all__ = ['trilinear_devoxelize']
|
__all__ = ["trilinear_devoxelize"]
|
||||||
|
|
||||||
|
|
||||||
class TrilinearDevoxelization(Function):
|
class TrilinearDevoxelization(Function):
|
||||||
|
@ -29,7 +29,7 @@ class TrilinearDevoxelization(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
"""
|
"""
|
||||||
:param ctx:
|
:param ctx:
|
||||||
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
|
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
|
||||||
:return:
|
:return:
|
||||||
gradient of inputs, FloatTensor[B, C, R, R, R]
|
gradient of inputs, FloatTensor[B, C, R, R, R]
|
||||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
||||||
|
|
||||||
from modules.functional.backend import _backend
|
from modules.functional.backend import _backend
|
||||||
|
|
||||||
__all__ = ['grouping']
|
__all__ = ["grouping"]
|
||||||
|
|
||||||
|
|
||||||
class Grouping(Function):
|
class Grouping(Function):
|
||||||
|
@ -23,7 +23,7 @@ class Grouping(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
indices, = ctx.saved_tensors
|
(indices,) = ctx.saved_tensors
|
||||||
grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
|
grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
|
||||||
return grad_features, None
|
return grad_features, None
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
||||||
|
|
||||||
from modules.functional.backend import _backend
|
from modules.functional.backend import _backend
|
||||||
|
|
||||||
__all__ = ['nearest_neighbor_interpolate']
|
__all__ = ["nearest_neighbor_interpolate"]
|
||||||
|
|
||||||
|
|
||||||
class NeighborInterpolation(Function):
|
class NeighborInterpolation(Function):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
__all__ = ['kl_loss', 'huber_loss']
|
__all__ = ["kl_loss", "huber_loss"]
|
||||||
|
|
||||||
|
|
||||||
def kl_loss(x, y):
|
def kl_loss(x, y):
|
||||||
|
@ -13,5 +13,5 @@ def kl_loss(x, y):
|
||||||
def huber_loss(error, delta):
|
def huber_loss(error, delta):
|
||||||
abs_error = torch.abs(error)
|
abs_error = torch.abs(error)
|
||||||
quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
|
quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
|
||||||
losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
|
losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
|
||||||
return torch.mean(losses)
|
return torch.mean(losses)
|
||||||
|
|
|
@ -4,7 +4,7 @@ from torch.autograd import Function
|
||||||
|
|
||||||
from modules.functional.backend import _backend
|
from modules.functional.backend import _backend
|
||||||
|
|
||||||
__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
|
__all__ = ["gather", "furthest_point_sample", "logits_mask"]
|
||||||
|
|
||||||
|
|
||||||
class Gather(Function):
|
class Gather(Function):
|
||||||
|
@ -26,7 +26,7 @@ class Gather(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
indices, = ctx.saved_tensors
|
(indices,) = ctx.saved_tensors
|
||||||
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
|
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
|
||||||
return grad_features, None
|
return grad_features, None
|
||||||
|
|
||||||
|
@ -60,11 +60,12 @@ def logits_mask(coords, logits, num_points_per_object):
|
||||||
mask: mask to select points, BoolTensor[B, N]
|
mask: mask to select points, BoolTensor[B, N]
|
||||||
"""
|
"""
|
||||||
batch_size, _, num_points = coords.shape
|
batch_size, _, num_points = coords.shape
|
||||||
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
|
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
|
||||||
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
|
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
|
||||||
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
|
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
|
||||||
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
|
masked_coords_mean = (
|
||||||
torch.ones_like(num_candidates)).float() # [B, C]
|
torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, torch.ones_like(num_candidates)).float()
|
||||||
|
) # [B, C]
|
||||||
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
|
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
current_mask = mask[i] # [N]
|
current_mask = mask[i] # [N]
|
||||||
|
@ -74,10 +75,14 @@ def logits_mask(coords, logits, num_points_per_object):
|
||||||
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
|
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
|
||||||
selected_indices[i] = current_candidates[choices]
|
selected_indices[i] = current_candidates[choices]
|
||||||
elif current_num_candidates > 0:
|
elif current_num_candidates > 0:
|
||||||
choices = np.concatenate([
|
choices = np.concatenate(
|
||||||
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
|
[
|
||||||
np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
|
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
|
||||||
])
|
np.random.choice(
|
||||||
|
current_num_candidates, num_points_per_object % current_num_candidates, replace=False
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
np.random.shuffle(choices)
|
np.random.shuffle(choices)
|
||||||
selected_indices[i] = current_candidates[choices]
|
selected_indices[i] = current_candidates[choices]
|
||||||
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
|
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
|
||||||
|
|
|
@ -2,7 +2,7 @@ from torch.autograd import Function
|
||||||
|
|
||||||
from modules.functional.backend import _backend
|
from modules.functional.backend import _backend
|
||||||
|
|
||||||
__all__ = ['avg_voxelize']
|
__all__ = ["avg_voxelize"]
|
||||||
|
|
||||||
|
|
||||||
class AvgVoxelization(Function):
|
class AvgVoxelization(Function):
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
import modules.functional as F
|
import modules.functional as F
|
||||||
|
|
||||||
__all__ = ['KLLoss']
|
__all__ = ["KLLoss"]
|
||||||
|
|
||||||
|
|
||||||
class KLLoss(nn.Module):
|
class KLLoss(nn.Module):
|
||||||
|
|
|
@ -5,7 +5,7 @@ import modules.functional as F
|
||||||
from modules.ball_query import BallQuery
|
from modules.ball_query import BallQuery
|
||||||
from modules.shared_mlp import SharedMLP
|
from modules.shared_mlp import SharedMLP
|
||||||
|
|
||||||
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
|
__all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]
|
||||||
|
|
||||||
|
|
||||||
class PointNetAModule(nn.Module):
|
class PointNetAModule(nn.Module):
|
||||||
|
@ -20,8 +20,9 @@ class PointNetAModule(nn.Module):
|
||||||
total_out_channels = 0
|
total_out_channels = 0
|
||||||
for _out_channels in out_channels:
|
for _out_channels in out_channels:
|
||||||
mlps.append(
|
mlps.append(
|
||||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
SharedMLP(
|
||||||
out_channels=_out_channels, dim=1)
|
in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=1
|
||||||
|
)
|
||||||
)
|
)
|
||||||
total_out_channels += _out_channels[-1]
|
total_out_channels += _out_channels[-1]
|
||||||
|
|
||||||
|
@ -43,7 +44,7 @@ class PointNetAModule(nn.Module):
|
||||||
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
|
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
|
return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"
|
||||||
|
|
||||||
|
|
||||||
class PointNetSAModule(nn.Module):
|
class PointNetSAModule(nn.Module):
|
||||||
|
@ -67,8 +68,9 @@ class PointNetSAModule(nn.Module):
|
||||||
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
|
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
|
||||||
)
|
)
|
||||||
mlps.append(
|
mlps.append(
|
||||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
SharedMLP(
|
||||||
out_channels=_out_channels, dim=2)
|
in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=2
|
||||||
|
)
|
||||||
)
|
)
|
||||||
total_out_channels += _out_channels[-1]
|
total_out_channels += _out_channels[-1]
|
||||||
|
|
||||||
|
@ -90,7 +92,7 @@ class PointNetSAModule(nn.Module):
|
||||||
return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
|
return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
|
return f"num_centers={self.num_centers}, out_channels={self.out_channels}"
|
||||||
|
|
||||||
|
|
||||||
class PointNetFPModule(nn.Module):
|
class PointNetFPModule(nn.Module):
|
||||||
|
@ -107,7 +109,5 @@ class PointNetFPModule(nn.Module):
|
||||||
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
|
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
|
||||||
interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
|
interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
|
||||||
if points_features is not None:
|
if points_features is not None:
|
||||||
interpolated_features = torch.cat(
|
interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
|
||||||
[interpolated_features, points_features], dim=1
|
|
||||||
)
|
|
||||||
return self.mlp(interpolated_features), points_coords, interpolated_temb
|
return self.mlp(interpolated_features), points_coords, interpolated_temb
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
import torch
|
||||||
import modules.functional as F
|
import torch.nn as nn
|
||||||
from modules.voxelization import Voxelization
|
|
||||||
from modules.shared_mlp import SharedMLP
|
|
||||||
from modules.se import SE3d
|
|
||||||
|
|
||||||
__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
|
import modules.functional as F
|
||||||
|
from modules.se import SE3d
|
||||||
|
from modules.shared_mlp import SharedMLP
|
||||||
|
from modules.voxelization import Voxelization
|
||||||
|
|
||||||
|
__all__ = ["PVConv", "Attention", "Swish", "PVConvReLU"]
|
||||||
|
|
||||||
|
|
||||||
class Swish(nn.Module):
|
class Swish(nn.Module):
|
||||||
def forward(self,x):
|
def forward(self, x):
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
|
@ -35,23 +36,19 @@ class Attention(nn.Module):
|
||||||
|
|
||||||
self.sm = nn.Softmax(-1)
|
self.sm = nn.Softmax(-1)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, C = x.shape[:2]
|
B, C = x.shape[:2]
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
|
q = self.q(h).reshape(B, C, -1)
|
||||||
|
k = self.k(h).reshape(B, C, -1)
|
||||||
|
v = self.v(h).reshape(B, C, -1)
|
||||||
|
|
||||||
|
qk = torch.matmul(q.permute(0, 2, 1), k) # * (int(C) ** (-0.5))
|
||||||
|
|
||||||
q = self.q(h).reshape(B,C,-1)
|
|
||||||
k = self.k(h).reshape(B,C,-1)
|
|
||||||
v = self.v(h).reshape(B,C,-1)
|
|
||||||
|
|
||||||
qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
|
|
||||||
|
|
||||||
w = self.sm(qk)
|
w = self.sm(qk)
|
||||||
|
|
||||||
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:])
|
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, *x.shape[2:])
|
||||||
|
|
||||||
h = self.out(h)
|
h = self.out(h)
|
||||||
|
|
||||||
|
@ -61,9 +58,21 @@ class Attention(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class PVConv(nn.Module):
|
class PVConv(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
|
def __init__(
|
||||||
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
resolution,
|
||||||
|
attention=False,
|
||||||
|
dropout=0.1,
|
||||||
|
with_se=False,
|
||||||
|
with_se_relu=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
@ -74,13 +83,13 @@ class PVConv(nn.Module):
|
||||||
voxel_layers = [
|
voxel_layers = [
|
||||||
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||||
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
||||||
Swish()
|
Swish(),
|
||||||
]
|
]
|
||||||
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
|
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
|
||||||
voxel_layers += [
|
voxel_layers += [
|
||||||
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||||
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
nn.GroupNorm(num_groups=8, num_channels=out_channels),
|
||||||
Attention(out_channels, 8) if attention else Swish()
|
Attention(out_channels, 8) if attention else Swish(),
|
||||||
]
|
]
|
||||||
if with_se:
|
if with_se:
|
||||||
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
|
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
|
||||||
|
@ -96,10 +105,21 @@ class PVConv(nn.Module):
|
||||||
return fused_features, coords, temb
|
return fused_features, coords, temb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PVConvReLU(nn.Module):
|
class PVConvReLU(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
|
def __init__(
|
||||||
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
resolution,
|
||||||
|
attention=False,
|
||||||
|
leak=0.2,
|
||||||
|
dropout=0.1,
|
||||||
|
with_se=False,
|
||||||
|
with_se_relu=False,
|
||||||
|
normalize=True,
|
||||||
|
eps=0,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
@ -110,13 +130,13 @@ class PVConvReLU(nn.Module):
|
||||||
voxel_layers = [
|
voxel_layers = [
|
||||||
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||||
nn.BatchNorm3d(out_channels),
|
nn.BatchNorm3d(out_channels),
|
||||||
nn.LeakyReLU(leak, True)
|
nn.LeakyReLU(leak, True),
|
||||||
]
|
]
|
||||||
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
|
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
|
||||||
voxel_layers += [
|
voxel_layers += [
|
||||||
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
|
||||||
nn.BatchNorm3d(out_channels),
|
nn.BatchNorm3d(out_channels),
|
||||||
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
|
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True),
|
||||||
]
|
]
|
||||||
if with_se:
|
if with_se:
|
||||||
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
|
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
|
||||||
|
|
|
@ -1,18 +1,22 @@
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
import torch
|
||||||
__all__ = ['SE3d']
|
import torch.nn as nn
|
||||||
|
|
||||||
|
__all__ = ["SE3d"]
|
||||||
|
|
||||||
|
|
||||||
class Swish(nn.Module):
|
class Swish(nn.Module):
|
||||||
def forward(self,x):
|
def forward(self, x):
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
class SE3d(nn.Module):
|
class SE3d(nn.Module):
|
||||||
def __init__(self, channel, reduction=8, use_relu=False):
|
def __init__(self, channel, reduction=8, use_relu=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(channel, channel // reduction, bias=False),
|
nn.Linear(channel, channel // reduction, bias=False),
|
||||||
nn.ReLU(True) if use_relu else Swish() ,
|
nn.ReLU(True) if use_relu else Swish(),
|
||||||
nn.Linear(channel // reduction, channel, bias=False),
|
nn.Linear(channel // reduction, channel, bias=False),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
import torch.nn as nn
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
__all__ = ['SharedMLP']
|
__all__ = ["SharedMLP"]
|
||||||
|
|
||||||
|
|
||||||
class Swish(nn.Module):
|
class Swish(nn.Module):
|
||||||
def forward(self,x):
|
def forward(self, x):
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
class SharedMLP(nn.Module):
|
class SharedMLP(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, dim=1):
|
def __init__(self, in_channels, out_channels, dim=1):
|
||||||
|
@ -23,11 +24,13 @@ class SharedMLP(nn.Module):
|
||||||
out_channels = [out_channels]
|
out_channels = [out_channels]
|
||||||
layers = []
|
layers = []
|
||||||
for oc in out_channels:
|
for oc in out_channels:
|
||||||
layers.extend([
|
layers.extend(
|
||||||
conv(in_channels, oc, 1),
|
[
|
||||||
bn(8, oc),
|
conv(in_channels, oc, 1),
|
||||||
Swish(),
|
bn(8, oc),
|
||||||
])
|
Swish(),
|
||||||
|
]
|
||||||
|
)
|
||||||
in_channels = oc
|
in_channels = oc
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
import modules.functional as F
|
import modules.functional as F
|
||||||
|
|
||||||
__all__ = ['Voxelization']
|
__all__ = ["Voxelization"]
|
||||||
|
|
||||||
|
|
||||||
class Voxelization(nn.Module):
|
class Voxelization(nn.Module):
|
||||||
|
@ -17,7 +17,10 @@ class Voxelization(nn.Module):
|
||||||
coords = coords.detach()
|
coords = coords.detach()
|
||||||
norm_coords = coords - coords.mean(2, keepdim=True)
|
norm_coords = coords - coords.mean(2, keepdim=True)
|
||||||
if self.normalize:
|
if self.normalize:
|
||||||
norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
|
norm_coords = (
|
||||||
|
norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps)
|
||||||
|
+ 0.5
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
norm_coords = (norm_coords + 1) / 2.0
|
norm_coords = (norm_coords + 1) / 2.0
|
||||||
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
|
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
|
||||||
|
@ -25,4 +28,4 @@ class Voxelization(nn.Module):
|
||||||
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
|
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
|
return "resolution={}{}".format(self.r, ", normalized eps = {}".format(self.eps) if self.normalize else "")
|
||||||
|
|
|
@ -1,26 +1,27 @@
|
||||||
|
import argparse
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
|
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
|
|
||||||
import argparse
|
|
||||||
from torch.distributions import Normal
|
from torch.distributions import Normal
|
||||||
from utils.file_utils import *
|
|
||||||
from model.pvcnn_completion import PVCNN2Base
|
|
||||||
|
|
||||||
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
|
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
|
||||||
from datasets.shapenet_data_sv import *
|
from datasets.shapenet_data_sv import *
|
||||||
'''
|
from metrics.evaluation_metrics import EMD_CD, compute_all_metrics
|
||||||
|
from model.pvcnn_completion import PVCNN2Base
|
||||||
|
from utils.file_utils import *
|
||||||
|
|
||||||
|
"""
|
||||||
models
|
models
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||||
"""
|
"""
|
||||||
KL divergence between normal distributions parameterized by mean and log-variance.
|
KL divergence between normal distributions parameterized by mean and log-variance.
|
||||||
"""
|
"""
|
||||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
|
||||||
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
|
|
||||||
|
|
||||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||||
# Assumes data is integers [0, 1]
|
# Assumes data is integers [0, 1]
|
||||||
|
@ -31,21 +32,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||||
inv_stdv = torch.exp(-log_scales)
|
inv_stdv = torch.exp(-log_scales)
|
||||||
plus_in = inv_stdv * (centered_x + 0.5)
|
plus_in = inv_stdv * (centered_x + 0.5)
|
||||||
cdf_plus = px0.cdf(plus_in)
|
cdf_plus = px0.cdf(plus_in)
|
||||||
min_in = inv_stdv * (centered_x - .5)
|
min_in = inv_stdv * (centered_x - 0.5)
|
||||||
cdf_min = px0.cdf(min_in)
|
cdf_min = px0.cdf(min_in)
|
||||||
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
|
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
|
||||||
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
|
log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
|
||||||
cdf_delta = cdf_plus - cdf_min
|
cdf_delta = cdf_plus - cdf_min
|
||||||
|
|
||||||
log_probs = torch.where(
|
log_probs = torch.where(
|
||||||
x < 0.001, log_cdf_plus,
|
x < 0.001,
|
||||||
torch.where(x > 0.999, log_one_minus_cdf_min,
|
log_cdf_plus,
|
||||||
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
|
torch.where(
|
||||||
|
x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
|
||||||
|
),
|
||||||
|
)
|
||||||
assert log_probs.shape == x.shape
|
assert log_probs.shape == x.shape
|
||||||
return log_probs
|
return log_probs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GaussianDiffusion:
|
class GaussianDiffusion:
|
||||||
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
|
def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
|
||||||
self.loss_type = loss_type
|
self.loss_type = loss_type
|
||||||
|
@ -54,15 +57,15 @@ class GaussianDiffusion:
|
||||||
assert isinstance(betas, np.ndarray)
|
assert isinstance(betas, np.ndarray)
|
||||||
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
|
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
|
||||||
assert (betas > 0).all() and (betas <= 1).all()
|
assert (betas > 0).all() and (betas <= 1).all()
|
||||||
timesteps, = betas.shape
|
(timesteps,) = betas.shape
|
||||||
self.num_timesteps = int(timesteps)
|
self.num_timesteps = int(timesteps)
|
||||||
self.sv_points = sv_points
|
self.sv_points = sv_points
|
||||||
# initialize twice the actual length so we can keep running for eval
|
# initialize twice the actual length so we can keep running for eval
|
||||||
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
|
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
|
||||||
|
|
||||||
alphas = 1. - betas
|
alphas = 1.0 - betas
|
||||||
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
|
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
|
||||||
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
|
alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
|
||||||
|
|
||||||
self.betas = torch.from_numpy(betas).float()
|
self.betas = torch.from_numpy(betas).float()
|
||||||
self.alphas_cumprod = alphas_cumprod.float()
|
self.alphas_cumprod = alphas_cumprod.float()
|
||||||
|
@ -70,21 +73,23 @@ class GaussianDiffusion:
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
|
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
|
||||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
|
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
|
||||||
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
|
self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
|
||||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
|
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
|
||||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
|
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
|
||||||
|
|
||||||
betas = torch.from_numpy(betas).float()
|
betas = torch.from_numpy(betas).float()
|
||||||
alphas = torch.from_numpy(alphas).float()
|
alphas = torch.from_numpy(alphas).float()
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||||
self.posterior_variance = posterior_variance
|
self.posterior_variance = posterior_variance
|
||||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||||
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
|
self.posterior_log_variance_clipped = torch.log(
|
||||||
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
|
torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
|
||||||
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
|
)
|
||||||
|
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||||
|
self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract(a, t, x_shape):
|
def _extract(a, t, x_shape):
|
||||||
|
@ -92,17 +97,15 @@ class GaussianDiffusion:
|
||||||
Extract some coefficients at specified timesteps,
|
Extract some coefficients at specified timesteps,
|
||||||
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
|
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
|
||||||
"""
|
"""
|
||||||
bs, = t.shape
|
(bs,) = t.shape
|
||||||
assert x_shape[0] == bs
|
assert x_shape[0] == bs
|
||||||
out = torch.gather(a, 0, t)
|
out = torch.gather(a, 0, t)
|
||||||
assert out.shape == torch.Size([bs])
|
assert out.shape == torch.Size([bs])
|
||||||
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
|
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def q_mean_variance(self, x_start, t):
|
def q_mean_variance(self, x_start, t):
|
||||||
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||||
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||||
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
|
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||||
return mean, variance, log_variance
|
return mean, variance, log_variance
|
||||||
|
|
||||||
|
@ -114,54 +117,59 @@ class GaussianDiffusion:
|
||||||
noise = torch.randn(x_start.shape, device=x_start.device)
|
noise = torch.randn(x_start.shape, device=x_start.device)
|
||||||
assert noise.shape == x_start.shape
|
assert noise.shape == x_start.shape
|
||||||
return (
|
return (
|
||||||
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||||
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
+ self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||||
"""
|
"""
|
||||||
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
|
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
|
||||||
"""
|
"""
|
||||||
assert x_start.shape == x_t.shape
|
assert x_start.shape == x_t.shape
|
||||||
posterior_mean = (
|
posterior_mean = (
|
||||||
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
|
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
|
||||||
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
+ self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
||||||
)
|
)
|
||||||
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
|
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
|
||||||
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
|
posterior_log_variance_clipped = self._extract(
|
||||||
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
|
self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
|
||||||
x_start.shape[0])
|
)
|
||||||
|
assert (
|
||||||
|
posterior_mean.shape[0]
|
||||||
|
== posterior_variance.shape[0]
|
||||||
|
== posterior_log_variance_clipped.shape[0]
|
||||||
|
== x_start.shape[0]
|
||||||
|
)
|
||||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||||
|
|
||||||
|
|
||||||
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
|
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
|
||||||
|
model_output = denoise_fn(data, t)[:, :, self.sv_points :]
|
||||||
|
|
||||||
model_output = denoise_fn(data, t)[:,:,self.sv_points:]
|
if self.model_var_type in ["fixedsmall", "fixedlarge"]:
|
||||||
|
|
||||||
|
|
||||||
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
|
|
||||||
# below: only log_variance is used in the KL computations
|
# below: only log_variance is used in the KL computations
|
||||||
model_variance, model_log_variance = {
|
model_variance, model_log_variance = {
|
||||||
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
|
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
|
||||||
'fixedlarge': (self.betas.to(data.device),
|
"fixedlarge": (
|
||||||
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
|
self.betas.to(data.device),
|
||||||
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
|
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
|
||||||
|
),
|
||||||
|
"fixedsmall": (
|
||||||
|
self.posterior_variance.to(data.device),
|
||||||
|
self.posterior_log_variance_clipped.to(data.device),
|
||||||
|
),
|
||||||
}[self.model_var_type]
|
}[self.model_var_type]
|
||||||
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
|
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
|
||||||
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
|
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self.model_var_type)
|
raise NotImplementedError(self.model_var_type)
|
||||||
|
|
||||||
if self.model_mean_type == 'eps':
|
if self.model_mean_type == "eps":
|
||||||
x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output)
|
x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output)
|
||||||
|
|
||||||
|
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t)
|
||||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self.loss_type)
|
raise NotImplementedError(self.loss_type)
|
||||||
|
|
||||||
|
|
||||||
assert model_mean.shape == x_recon.shape
|
assert model_mean.shape == x_recon.shape
|
||||||
assert model_variance.shape == model_log_variance.shape
|
assert model_variance.shape == model_log_variance.shape
|
||||||
if return_pred_xstart:
|
if return_pred_xstart:
|
||||||
|
@ -172,30 +180,31 @@ class GaussianDiffusion:
|
||||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||||
assert x_t.shape == eps.shape
|
assert x_t.shape == eps.shape
|
||||||
return (
|
return (
|
||||||
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
|
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
|
||||||
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
||||||
)
|
)
|
||||||
|
|
||||||
''' samples '''
|
""" samples """
|
||||||
|
|
||||||
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
|
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
|
||||||
"""
|
"""
|
||||||
Sample from the model
|
Sample from the model
|
||||||
"""
|
"""
|
||||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
|
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
||||||
return_pred_xstart=True)
|
denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
|
||||||
|
)
|
||||||
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
|
noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)
|
||||||
|
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
|
nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
|
||||||
|
|
||||||
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
|
sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
|
||||||
sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1)
|
sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1)
|
||||||
return (sample, pred_xstart) if return_pred_xstart else sample
|
return (sample, pred_xstart) if return_pred_xstart else sample
|
||||||
|
|
||||||
|
def p_sample_loop(
|
||||||
def p_sample_loop(self, partial_x, denoise_fn, shape, device,
|
self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False
|
||||||
noise_fn=torch.randn, clip_denoised=True, keep_running=False):
|
):
|
||||||
"""
|
"""
|
||||||
Generate samples
|
Generate samples
|
||||||
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
|
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
|
||||||
|
@ -206,14 +215,21 @@ class GaussianDiffusion:
|
||||||
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
|
img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
|
||||||
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
|
for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
|
||||||
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
||||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
|
img_t = self.p_sample(
|
||||||
clip_denoised=clip_denoised, return_pred_xstart=False)
|
denoise_fn=denoise_fn,
|
||||||
|
data=img_t,
|
||||||
|
t=t_,
|
||||||
|
noise_fn=noise_fn,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
return_pred_xstart=False,
|
||||||
|
)
|
||||||
|
|
||||||
assert img_t[:,:,self.sv_points:].shape == shape
|
assert img_t[:, :, self.sv_points :].shape == shape
|
||||||
return img_t
|
return img_t
|
||||||
|
|
||||||
def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq,
|
def p_sample_loop_trajectory(
|
||||||
noise_fn=torch.randn,clip_denoised=True, keep_running=False):
|
self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generate samples, returning intermediate images
|
Generate samples, returning intermediate images
|
||||||
Useful for visualizing how denoised images evolve over time
|
Useful for visualizing how denoised images evolve over time
|
||||||
|
@ -223,31 +239,38 @@ class GaussianDiffusion:
|
||||||
"""
|
"""
|
||||||
assert isinstance(shape, (tuple, list))
|
assert isinstance(shape, (tuple, list))
|
||||||
|
|
||||||
total_steps = self.num_timesteps if not keep_running else len(self.betas)
|
total_steps = self.num_timesteps if not keep_running else len(self.betas)
|
||||||
|
|
||||||
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
|
img_t = noise_fn(size=shape, dtype=torch.float, device=device)
|
||||||
imgs = [img_t]
|
imgs = [img_t]
|
||||||
for t in reversed(range(0,total_steps)):
|
for t in reversed(range(0, total_steps)):
|
||||||
|
|
||||||
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
||||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
|
img_t = self.p_sample(
|
||||||
clip_denoised=clip_denoised,
|
denoise_fn=denoise_fn,
|
||||||
return_pred_xstart=False)
|
data=img_t,
|
||||||
if t % freq == 0 or t == total_steps-1:
|
t=t_,
|
||||||
|
noise_fn=noise_fn,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
return_pred_xstart=False,
|
||||||
|
)
|
||||||
|
if t % freq == 0 or t == total_steps - 1:
|
||||||
imgs.append(img_t)
|
imgs.append(img_t)
|
||||||
|
|
||||||
assert imgs[-1].shape == shape
|
assert imgs[-1].shape == shape
|
||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
'''losses'''
|
"""losses"""
|
||||||
|
|
||||||
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
|
def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
|
||||||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t)
|
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
||||||
|
x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t
|
||||||
|
)
|
||||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
||||||
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
|
denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
|
||||||
|
)
|
||||||
|
|
||||||
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
|
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
|
||||||
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.)
|
kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0)
|
||||||
|
|
||||||
return (kl, pred_xstart) if return_pred_xstart else kl
|
return (kl, pred_xstart) if return_pred_xstart else kl
|
||||||
|
|
||||||
|
@ -259,66 +282,87 @@ class GaussianDiffusion:
|
||||||
assert t.shape == torch.Size([B])
|
assert t.shape == torch.Size([B])
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device)
|
noise = torch.randn(
|
||||||
|
data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device
|
||||||
|
)
|
||||||
|
|
||||||
data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise)
|
data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise)
|
||||||
|
|
||||||
if self.loss_type == 'mse':
|
if self.loss_type == "mse":
|
||||||
# predict the noise instead of x_start. seems to be weighted naturally like SNR
|
# predict the noise instead of x_start. seems to be weighted naturally like SNR
|
||||||
eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:]
|
eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[
|
||||||
|
:, :, self.sv_points :
|
||||||
|
]
|
||||||
|
|
||||||
losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
|
losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
|
||||||
elif self.loss_type == 'kl':
|
elif self.loss_type == "kl":
|
||||||
losses = self._vb_terms_bpd(
|
losses = self._vb_terms_bpd(
|
||||||
denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False,
|
denoise_fn=denoise_fn,
|
||||||
return_pred_xstart=False)
|
data_start=data_start,
|
||||||
|
data_t=data_t,
|
||||||
|
t=t,
|
||||||
|
clip_denoised=False,
|
||||||
|
return_pred_xstart=False,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self.loss_type)
|
raise NotImplementedError(self.loss_type)
|
||||||
|
|
||||||
assert losses.shape == torch.Size([B])
|
assert losses.shape == torch.Size([B])
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
'''debug'''
|
"""debug"""
|
||||||
|
|
||||||
def _prior_bpd(self, x_start):
|
def _prior_bpd(self, x_start):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
B, T = x_start.shape[0], self.num_timesteps
|
B, T = x_start.shape[0], self.num_timesteps
|
||||||
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1)
|
t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
|
||||||
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
|
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
|
||||||
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance,
|
kl_prior = normal_kl(
|
||||||
mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance))
|
mean1=qt_mean,
|
||||||
|
logvar1=qt_log_variance,
|
||||||
|
mean2=torch.tensor([0.0]).to(qt_mean),
|
||||||
|
logvar2=torch.tensor([0.0]).to(qt_log_variance),
|
||||||
|
)
|
||||||
assert kl_prior.shape == x_start.shape
|
assert kl_prior.shape == x_start.shape
|
||||||
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.)
|
return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0)
|
||||||
|
|
||||||
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
|
def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
B, T = x_start.shape[0], self.num_timesteps
|
B, T = x_start.shape[0], self.num_timesteps
|
||||||
|
|
||||||
vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
|
vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
|
||||||
for t in reversed(range(T)):
|
for t in reversed(range(T)):
|
||||||
|
|
||||||
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
|
t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
|
||||||
# Calculate VLB term at the current timestep
|
# Calculate VLB term at the current timestep
|
||||||
data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1)
|
data_t = torch.cat(
|
||||||
|
[x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
new_vals_b, pred_xstart = self._vb_terms_bpd(
|
new_vals_b, pred_xstart = self._vb_terms_bpd(
|
||||||
denoise_fn, data_start=x_start, data_t=data_t, t=t_b,
|
denoise_fn,
|
||||||
clip_denoised=clip_denoised, return_pred_xstart=True)
|
data_start=x_start,
|
||||||
|
data_t=data_t,
|
||||||
|
t=t_b,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
return_pred_xstart=True,
|
||||||
|
)
|
||||||
# MSE for progressive prediction loss
|
# MSE for progressive prediction loss
|
||||||
assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape
|
assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape
|
||||||
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape))))
|
new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean(
|
||||||
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
|
dim=list(range(1, len(pred_xstart.shape)))
|
||||||
|
)
|
||||||
|
assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
|
||||||
# Insert the calculated term into the tensor of all terms
|
# Insert the calculated term into the tensor of all terms
|
||||||
mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float()
|
mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
|
||||||
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
|
vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
|
||||||
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
|
mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
|
||||||
assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
|
assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])
|
||||||
|
|
||||||
prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:])
|
prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :])
|
||||||
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
|
total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
|
||||||
assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \
|
assert vals_bt_.shape == mse_bt_.shape == torch.Size(
|
||||||
total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
|
[B, T]
|
||||||
|
) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
|
||||||
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
|
return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
|
||||||
|
|
||||||
|
|
||||||
|
@ -336,39 +380,53 @@ class PVCNN2(PVCNN2Base):
|
||||||
((128, 128, 64), (64, 2, 32)),
|
((128, 128, 64), (64, 2, 32)),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
|
def __init__(
|
||||||
voxel_resolution_multiplier=1):
|
self,
|
||||||
|
num_classes,
|
||||||
|
sv_points,
|
||||||
|
embed_dim,
|
||||||
|
use_att,
|
||||||
|
dropout,
|
||||||
|
extra_feature_channels=3,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att,
|
num_classes=num_classes,
|
||||||
dropout=dropout, extra_feature_channels=extra_feature_channels,
|
sv_points=sv_points,
|
||||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
embed_dim=embed_dim,
|
||||||
|
use_att=use_att,
|
||||||
|
dropout=dropout,
|
||||||
|
extra_feature_channels=extra_feature_channels,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
|
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
|
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)
|
||||||
|
|
||||||
self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention,
|
self.model = PVCNN2(
|
||||||
dropout=args.dropout, extra_feature_channels=0)
|
num_classes=args.nc,
|
||||||
|
sv_points=args.svpoints,
|
||||||
|
embed_dim=args.embed_dim,
|
||||||
|
use_att=args.attention,
|
||||||
|
dropout=args.dropout,
|
||||||
|
extra_feature_channels=0,
|
||||||
|
)
|
||||||
|
|
||||||
def prior_kl(self, x0):
|
def prior_kl(self, x0):
|
||||||
return self.diffusion._prior_bpd(x0)
|
return self.diffusion._prior_bpd(x0)
|
||||||
|
|
||||||
def all_kl(self, x0, clip_denoised=True):
|
def all_kl(self, x0, clip_denoised=True):
|
||||||
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
||||||
|
|
||||||
return {
|
|
||||||
'total_bpd_b': total_bpd_b,
|
|
||||||
'terms_bpd': vals_bt,
|
|
||||||
'prior_bpd_b': prior_bpd_b,
|
|
||||||
'mse_bt':mse_bt
|
|
||||||
}
|
|
||||||
|
|
||||||
|
return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
|
||||||
|
|
||||||
def _denoise(self, data, t):
|
def _denoise(self, data, t):
|
||||||
B, D,N= data.shape
|
B, D, N = data.shape
|
||||||
assert data.dtype == torch.float
|
assert data.dtype == torch.float
|
||||||
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
|
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
|
||||||
|
|
||||||
|
@ -381,20 +439,22 @@ class Model(nn.Module):
|
||||||
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
|
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
|
||||||
|
|
||||||
if noises is not None:
|
if noises is not None:
|
||||||
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
|
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
|
||||||
|
|
||||||
losses = self.diffusion.p_losses(
|
losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
||||||
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
|
||||||
assert losses.shape == t.shape == torch.Size([B])
|
assert losses.shape == t.shape == torch.Size([B])
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn,
|
def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
|
||||||
clip_denoised=True,
|
return self.diffusion.p_sample_loop(
|
||||||
keep_running=False):
|
partial_x,
|
||||||
return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn,
|
self._denoise,
|
||||||
clip_denoised=clip_denoised,
|
shape=shape,
|
||||||
keep_running=keep_running)
|
device=device,
|
||||||
|
noise_fn=noise_fn,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
keep_running=keep_running,
|
||||||
|
)
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
@ -405,21 +465,19 @@ class Model(nn.Module):
|
||||||
def multi_gpu_wrapper(self, f):
|
def multi_gpu_wrapper(self, f):
|
||||||
self.model = f(self.model)
|
self.model = f(self.model)
|
||||||
|
|
||||||
def get_betas(schedule_type, b_start, b_end, time_num):
|
|
||||||
if schedule_type == 'linear':
|
|
||||||
betas = np.linspace(b_start, b_end, time_num)
|
|
||||||
elif schedule_type == 'warm0.1':
|
|
||||||
|
|
||||||
|
def get_betas(schedule_type, b_start, b_end, time_num):
|
||||||
|
if schedule_type == "linear":
|
||||||
|
betas = np.linspace(b_start, b_end, time_num)
|
||||||
|
elif schedule_type == "warm0.1":
|
||||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||||
warmup_time = int(time_num * 0.1)
|
warmup_time = int(time_num * 0.1)
|
||||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||||
elif schedule_type == 'warm0.2':
|
elif schedule_type == "warm0.2":
|
||||||
|
|
||||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||||
warmup_time = int(time_num * 0.2)
|
warmup_time = int(time_num * 0.2)
|
||||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||||
elif schedule_type == 'warm0.5':
|
elif schedule_type == "warm0.5":
|
||||||
|
|
||||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||||
warmup_time = int(time_num * 0.5)
|
warmup_time = int(time_num * 0.5)
|
||||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||||
|
@ -428,22 +486,29 @@ def get_betas(schedule_type, b_start, b_end, time_num):
|
||||||
return betas
|
return betas
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#############################################################################
|
#############################################################################
|
||||||
|
|
||||||
def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
|
|
||||||
tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot,
|
def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
|
||||||
categories=[category], split='train',
|
tr_dataset = ShapeNet15kPointClouds(
|
||||||
|
root_dir=pc_dataroot,
|
||||||
|
categories=[category],
|
||||||
|
split="train",
|
||||||
tr_sample_size=npoints,
|
tr_sample_size=npoints,
|
||||||
te_sample_size=npoints,
|
te_sample_size=npoints,
|
||||||
scale=1.,
|
scale=1.0,
|
||||||
normalize_per_shape=False,
|
normalize_per_shape=False,
|
||||||
normalize_std_per_axis=False,
|
normalize_std_per_axis=False,
|
||||||
random_subsample=True)
|
random_subsample=True,
|
||||||
te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root,
|
)
|
||||||
cache=os.path.join(pc_dataroot, '../cache'), split='val',
|
te_dataset = ShapeNet_Multiview_Points(
|
||||||
|
root_pc=pc_dataroot,
|
||||||
|
root_views=views_root,
|
||||||
|
cache=os.path.join(pc_dataroot, "../cache"),
|
||||||
|
split="val",
|
||||||
categories=[category],
|
categories=[category],
|
||||||
npoints=npoints, sv_samples=200,
|
npoints=npoints,
|
||||||
|
sv_samples=200,
|
||||||
all_points_mean=tr_dataset.all_points_mean,
|
all_points_mean=tr_dataset.all_points_mean,
|
||||||
all_points_std=tr_dataset.all_points_std,
|
all_points_std=tr_dataset.all_points_std,
|
||||||
)
|
)
|
||||||
|
@ -451,39 +516,41 @@ def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
|
||||||
|
|
||||||
|
|
||||||
def evaluate_recon_mvr(opt, model, save_dir, logger):
|
def evaluate_recon_mvr(opt, model, save_dir, logger):
|
||||||
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv,
|
test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.category)
|
||||||
opt.npoints, opt.category)
|
|
||||||
|
|
||||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
|
test_dataloader = torch.utils.data.DataLoader(
|
||||||
shuffle=False, num_workers=int(opt.workers), drop_last=False)
|
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
|
||||||
|
)
|
||||||
ref = []
|
ref = []
|
||||||
samples = []
|
samples = []
|
||||||
masked = []
|
masked = []
|
||||||
k = 0
|
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Reconstructing Samples"):
|
||||||
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'):
|
gt_all = data["test_points"]
|
||||||
|
x_all = data["sv_points"]
|
||||||
|
|
||||||
gt_all = data['test_points']
|
B, V, N, C = x_all.shape
|
||||||
x_all = data['sv_points']
|
gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)
|
||||||
|
|
||||||
B,V,N,C = x_all.shape
|
|
||||||
gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1)
|
|
||||||
|
|
||||||
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
|
x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
m, s = data['mean'].float(), data['std'].float()
|
m, s = data["mean"].float(), data["std"].float()
|
||||||
|
|
||||||
recon = model.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda',
|
recon = (
|
||||||
clip_denoised=False).detach().cpu()
|
model.gen_samples(
|
||||||
|
x[:, :, : opt.svpoints].cuda(), x[:, :, opt.svpoints :].shape, "cuda", clip_denoised=False
|
||||||
|
)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
)
|
||||||
|
|
||||||
recon = recon.transpose(1, 2).contiguous()
|
recon = recon.transpose(1, 2).contiguous()
|
||||||
x = x.transpose(1, 2).contiguous()
|
x = x.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
x_adj = x.reshape(B, V, N, C) * s + m
|
||||||
|
recon_adj = recon.reshape(B, V, N, C) * s + m
|
||||||
|
|
||||||
x_adj = x.reshape(B,V,N,C)* s + m
|
ref.append(gt_all * s + m)
|
||||||
recon_adj = recon.reshape(B,V,N,C)* s + m
|
masked.append(x_adj[:, :, : test_dataloader.dataset.sv_samples, :])
|
||||||
|
|
||||||
ref.append( gt_all * s + m)
|
|
||||||
masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:])
|
|
||||||
samples.append(recon_adj)
|
samples.append(recon_adj)
|
||||||
|
|
||||||
ref_pcs = torch.cat(ref, dim=0)
|
ref_pcs = torch.cat(ref, dim=0)
|
||||||
|
@ -492,31 +559,40 @@ def evaluate_recon_mvr(opt, model, save_dir, logger):
|
||||||
|
|
||||||
B, V, N, C = ref_pcs.shape
|
B, V, N, C = ref_pcs.shape
|
||||||
|
|
||||||
|
torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, "recon_gt.pth"))
|
||||||
|
|
||||||
torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth'))
|
torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, "recon_masked.pth"))
|
||||||
|
|
||||||
torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth'))
|
|
||||||
# Compute metrics
|
# Compute metrics
|
||||||
results = EMD_CD(sample_pcs.reshape(B*V, N, C),
|
results = EMD_CD(sample_pcs.reshape(B * V, N, C), ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)
|
||||||
ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False)
|
|
||||||
|
|
||||||
results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()}
|
results = {
|
||||||
|
ky: val.reshape(B, V)
|
||||||
|
if val.shape
|
||||||
|
== torch.Size(
|
||||||
|
[
|
||||||
|
B * V,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else val
|
||||||
|
for ky, val in results.items()
|
||||||
|
}
|
||||||
|
|
||||||
pprint({key: val.mean().item() for key, val in results.items()})
|
pprint({key: val.mean().item() for key, val in results.items()})
|
||||||
logger.info({key: val.mean().item() for key, val in results.items()})
|
logger.info({key: val.mean().item() for key, val in results.items()})
|
||||||
|
|
||||||
results['pc'] = sample_pcs
|
results["pc"] = sample_pcs
|
||||||
torch.save(results, os.path.join(save_dir, 'ours_results.pth'))
|
torch.save(results, os.path.join(save_dir, "ours_results.pth"))
|
||||||
|
|
||||||
del ref_pcs, masked, results
|
del ref_pcs, masked, results
|
||||||
|
|
||||||
|
|
||||||
def evaluate_saved(opt, saved_dir):
|
def evaluate_saved(opt, saved_dir):
|
||||||
# ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
|
# ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
|
||||||
|
|
||||||
gt_pth = saved_dir + '/recon_gt.pth'
|
gt_pth = saved_dir + "/recon_gt.pth"
|
||||||
ours_pth = saved_dir + '/ours_results.pth'
|
ours_pth = saved_dir + "/ours_results.pth"
|
||||||
gt = torch.load(gt_pth).permute(1,0,2,3)
|
gt = torch.load(gt_pth).permute(1, 0, 2, 3)
|
||||||
ours = torch.load(ours_pth)['pc'].permute(1,0,2,3)
|
ours = torch.load(ours_pth)["pc"].permute(1, 0, 2, 3)
|
||||||
|
|
||||||
all_res = {}
|
all_res = {}
|
||||||
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
|
for i, (gt_, ours_) in enumerate(zip(gt, ours)):
|
||||||
|
@ -534,7 +610,6 @@ def evaluate_saved(opt, saved_dir):
|
||||||
pprint({key: val.mean().item() for key, val in all_res.items()})
|
pprint({key: val.mean().item() for key, val in all_res.items()})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
exp_id = os.path.splitext(os.path.basename(__file__))[0]
|
exp_id = os.path.splitext(os.path.basename(__file__))[0]
|
||||||
dir_id = os.path.dirname(__file__)
|
dir_id = os.path.dirname(__file__)
|
||||||
|
@ -542,7 +617,7 @@ def main(opt):
|
||||||
copy_source(__file__, output_dir)
|
copy_source(__file__, output_dir)
|
||||||
logger = setup_logging(output_dir)
|
logger = setup_logging(output_dir)
|
||||||
|
|
||||||
outf_syn, = setup_output_subdirs(output_dir, 'syn')
|
(outf_syn,) = setup_output_subdirs(output_dir, "syn")
|
||||||
|
|
||||||
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
||||||
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
||||||
|
@ -559,12 +634,10 @@ def main(opt):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
logger.info("Resume Path:%s" % opt.model)
|
logger.info("Resume Path:%s" % opt.model)
|
||||||
|
|
||||||
resumed_param = torch.load(opt.model)
|
resumed_param = torch.load(opt.model)
|
||||||
model.load_state_dict(resumed_param['model_state'])
|
model.load_state_dict(resumed_param["model_state"])
|
||||||
|
|
||||||
|
|
||||||
if opt.eval_recon_mvr:
|
if opt.eval_recon_mvr:
|
||||||
# Evaluate generation
|
# Evaluate generation
|
||||||
|
@ -575,47 +648,44 @@ def main(opt):
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/')
|
parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/")
|
||||||
parser.add_argument('--dataroot_sv', default='GenReData/')
|
parser.add_argument("--dataroot_sv", default="GenReData/")
|
||||||
parser.add_argument('--category', default='chair')
|
parser.add_argument("--category", default="chair")
|
||||||
|
|
||||||
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
|
parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
|
||||||
parser.add_argument('--workers', type=int, default=16, help='workers')
|
parser.add_argument("--workers", type=int, default=16, help="workers")
|
||||||
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
|
parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
|
||||||
|
|
||||||
parser.add_argument('--eval_recon_mvr', default=True)
|
parser.add_argument("--eval_recon_mvr", default=True)
|
||||||
parser.add_argument('--eval_saved', default=True)
|
parser.add_argument("--eval_saved", default=True)
|
||||||
|
|
||||||
parser.add_argument('--nc', default=3)
|
parser.add_argument("--nc", default=3)
|
||||||
parser.add_argument('--npoints', default=2048)
|
parser.add_argument("--npoints", default=2048)
|
||||||
parser.add_argument('--svpoints', default=200)
|
parser.add_argument("--svpoints", default=200)
|
||||||
'''model'''
|
"""model"""
|
||||||
parser.add_argument('--beta_start', default=0.0001)
|
parser.add_argument("--beta_start", default=0.0001)
|
||||||
parser.add_argument('--beta_end', default=0.02)
|
parser.add_argument("--beta_end", default=0.02)
|
||||||
parser.add_argument('--schedule_type', default='linear')
|
parser.add_argument("--schedule_type", default="linear")
|
||||||
parser.add_argument('--time_num', default=1000)
|
parser.add_argument("--time_num", default=1000)
|
||||||
|
|
||||||
#params
|
# params
|
||||||
parser.add_argument('--attention', default=True)
|
parser.add_argument("--attention", default=True)
|
||||||
parser.add_argument('--dropout', default=0.1)
|
parser.add_argument("--dropout", default=0.1)
|
||||||
parser.add_argument('--embed_dim', type=int, default=64)
|
parser.add_argument("--embed_dim", type=int, default=64)
|
||||||
parser.add_argument('--loss_type', default='mse')
|
parser.add_argument("--loss_type", default="mse")
|
||||||
parser.add_argument('--model_mean_type', default='eps')
|
parser.add_argument("--model_mean_type", default="eps")
|
||||||
parser.add_argument('--model_var_type', default='fixedsmall')
|
parser.add_argument("--model_var_type", default="fixedsmall")
|
||||||
|
|
||||||
|
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
|
||||||
|
|
||||||
parser.add_argument('--model', default='', required=True, help="path to model (to continue training)")
|
"""eval"""
|
||||||
|
|
||||||
'''eval'''
|
parser.add_argument("--eval_path", default="")
|
||||||
|
|
||||||
parser.add_argument('--eval_path',
|
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
|
||||||
default='')
|
|
||||||
|
|
||||||
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
|
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
|
||||||
|
|
||||||
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
|
|
||||||
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
@ -625,7 +695,9 @@ def parse_args():
|
||||||
opt.cuda = False
|
opt.cuda = False
|
||||||
|
|
||||||
return opt
|
return opt
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
opt = parse_args()
|
opt = parse_args()
|
||||||
|
|
||||||
main(opt)
|
main(opt)
|
||||||
|
|
|
@ -1,31 +1,30 @@
|
||||||
import torch
|
import argparse
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
|
||||||
from metrics.evaluation_metrics import compute_all_metrics, EMD_CD
|
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
|
|
||||||
import argparse
|
|
||||||
from torch.distributions import Normal
|
from torch.distributions import Normal
|
||||||
|
|
||||||
from utils.file_utils import *
|
|
||||||
from utils.visualize import *
|
|
||||||
from model.pvcnn_generation import PVCNN2Base
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
|
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
|
||||||
|
from metrics.evaluation_metrics import compute_all_metrics
|
||||||
|
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
||||||
|
from model.pvcnn_generation import PVCNN2Base
|
||||||
|
from utils.file_utils import *
|
||||||
|
from utils.visualize import *
|
||||||
|
|
||||||
'''
|
"""
|
||||||
models
|
models
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||||
"""
|
"""
|
||||||
KL divergence between normal distributions parameterized by mean and log-variance.
|
KL divergence between normal distributions parameterized by mean and log-variance.
|
||||||
"""
|
"""
|
||||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
|
||||||
+ (mean1 - mean2)**2 * torch.exp(-logvar2))
|
|
||||||
|
|
||||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||||
# Assumes data is integers [0, 1]
|
# Assumes data is integers [0, 1]
|
||||||
|
@ -36,37 +35,40 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||||
inv_stdv = torch.exp(-log_scales)
|
inv_stdv = torch.exp(-log_scales)
|
||||||
plus_in = inv_stdv * (centered_x + 0.5)
|
plus_in = inv_stdv * (centered_x + 0.5)
|
||||||
cdf_plus = px0.cdf(plus_in)
|
cdf_plus = px0.cdf(plus_in)
|
||||||
min_in = inv_stdv * (centered_x - .5)
|
min_in = inv_stdv * (centered_x - 0.5)
|
||||||
cdf_min = px0.cdf(min_in)
|
cdf_min = px0.cdf(min_in)
|
||||||
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12))
|
log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
|
||||||
log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12))
|
log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
|
||||||
cdf_delta = cdf_plus - cdf_min
|
cdf_delta = cdf_plus - cdf_min
|
||||||
|
|
||||||
log_probs = torch.where(
|
log_probs = torch.where(
|
||||||
x < 0.001, log_cdf_plus,
|
x < 0.001,
|
||||||
torch.where(x > 0.999, log_one_minus_cdf_min,
|
log_cdf_plus,
|
||||||
torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12))))
|
torch.where(
|
||||||
|
x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
|
||||||
|
),
|
||||||
|
)
|
||||||
assert log_probs.shape == x.shape
|
assert log_probs.shape == x.shape
|
||||||
return log_probs
|
return log_probs
|
||||||
|
|
||||||
|
|
||||||
class GaussianDiffusion:
|
class GaussianDiffusion:
|
||||||
def __init__(self,betas, loss_type, model_mean_type, model_var_type):
|
def __init__(self, betas, loss_type, model_mean_type, model_var_type):
|
||||||
self.loss_type = loss_type
|
self.loss_type = loss_type
|
||||||
self.model_mean_type = model_mean_type
|
self.model_mean_type = model_mean_type
|
||||||
self.model_var_type = model_var_type
|
self.model_var_type = model_var_type
|
||||||
assert isinstance(betas, np.ndarray)
|
assert isinstance(betas, np.ndarray)
|
||||||
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
|
self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
|
||||||
assert (betas > 0).all() and (betas <= 1).all()
|
assert (betas > 0).all() and (betas <= 1).all()
|
||||||
timesteps, = betas.shape
|
(timesteps,) = betas.shape
|
||||||
self.num_timesteps = int(timesteps)
|
self.num_timesteps = int(timesteps)
|
||||||
|
|
||||||
# initialize twice the actual length so we can keep running for eval
|
# initialize twice the actual length so we can keep running for eval
|
||||||
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
|
# betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])
|
||||||
|
|
||||||
alphas = 1. - betas
|
alphas = 1.0 - betas
|
||||||
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
|
alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
|
||||||
alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float()
|
alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()
|
||||||
|
|
||||||
self.betas = torch.from_numpy(betas).float()
|
self.betas = torch.from_numpy(betas).float()
|
||||||
self.alphas_cumprod = alphas_cumprod.float()
|
self.alphas_cumprod = alphas_cumprod.float()
|
||||||
|
@ -74,21 +76,23 @@ class GaussianDiffusion:
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
|
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
|
||||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float()
|
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
|
||||||
self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float()
|
self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
|
||||||
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
|
self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
|
||||||
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
|
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()
|
||||||
|
|
||||||
betas = torch.from_numpy(betas).float()
|
betas = torch.from_numpy(betas).float()
|
||||||
alphas = torch.from_numpy(alphas).float()
|
alphas = torch.from_numpy(alphas).float()
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||||
self.posterior_variance = posterior_variance
|
self.posterior_variance = posterior_variance
|
||||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||||
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
|
self.posterior_log_variance_clipped = torch.log(
|
||||||
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
|
torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
|
||||||
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
|
)
|
||||||
|
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||||
|
self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract(a, t, x_shape):
|
def _extract(a, t, x_shape):
|
||||||
|
@ -96,17 +100,15 @@ class GaussianDiffusion:
|
||||||
Extract some coefficients at specified timesteps,
|
Extract some coefficients at specified timesteps,
|
||||||
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
|
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
|
||||||
"""
|
"""
|
||||||
bs, = t.shape
|
(bs,) = t.shape
|
||||||
assert x_shape[0] == bs
|
assert x_shape[0] == bs
|
||||||
out = torch.gather(a, 0, t)
|
out = torch.gather(a, 0, t)
|
||||||
assert out.shape == torch.Size([bs])
|
assert out.shape == torch.Size([bs])
|
||||||
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
|
return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def q_mean_variance(self, x_start, t):
|
def q_mean_variance(self, x_start, t):
|
||||||
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||||
variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||||
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
|
log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
|
||||||
return mean, variance, log_variance
|
return mean, variance, log_variance
|
||||||
|
|
||||||
|
@ -118,56 +120,62 @@ class GaussianDiffusion:
|
||||||
noise = torch.randn(x_start.shape, device=x_start.device)
|
noise = torch.randn(x_start.shape, device=x_start.device)
|
||||||
assert noise.shape == x_start.shape
|
assert noise.shape == x_start.shape
|
||||||
return (
|
return (
|
||||||
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
|
||||||
self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
+ self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||||
"""
|
"""
|
||||||
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
|
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
|
||||||
"""
|
"""
|
||||||
assert x_start.shape == x_t.shape
|
assert x_start.shape == x_t.shape
|
||||||
posterior_mean = (
|
posterior_mean = (
|
||||||
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
|
self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
|
||||||
self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
+ self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
|
||||||
)
|
)
|
||||||
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
|
posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
|
||||||
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
|
posterior_log_variance_clipped = self._extract(
|
||||||
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
|
self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
|
||||||
x_start.shape[0])
|
)
|
||||||
|
assert (
|
||||||
|
posterior_mean.shape[0]
|
||||||
|
== posterior_variance.shape[0]
|
||||||
|
== posterior_log_variance_clipped.shape[0]
|
||||||
|
== x_start.shape[0]
|
||||||
|
)
|
||||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||||
|
|
||||||
|
|
||||||
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
|
def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
|
||||||
|
|
||||||
model_output = denoise_fn(data, t)
|
model_output = denoise_fn(data, t)
|
||||||
|
|
||||||
|
if self.model_var_type in ["fixedsmall", "fixedlarge"]:
|
||||||
if self.model_var_type in ['fixedsmall', 'fixedlarge']:
|
|
||||||
# below: only log_variance is used in the KL computations
|
# below: only log_variance is used in the KL computations
|
||||||
model_variance, model_log_variance = {
|
model_variance, model_log_variance = {
|
||||||
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
|
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
|
||||||
'fixedlarge': (self.betas.to(data.device),
|
"fixedlarge": (
|
||||||
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
|
self.betas.to(data.device),
|
||||||
'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
|
torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
|
||||||
|
),
|
||||||
|
"fixedsmall": (
|
||||||
|
self.posterior_variance.to(data.device),
|
||||||
|
self.posterior_log_variance_clipped.to(data.device),
|
||||||
|
),
|
||||||
}[self.model_var_type]
|
}[self.model_var_type]
|
||||||
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
|
model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
|
||||||
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
|
model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self.model_var_type)
|
raise NotImplementedError(self.model_var_type)
|
||||||
|
|
||||||
if self.model_mean_type == 'eps':
|
if self.model_mean_type == "eps":
|
||||||
x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
|
x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
|
||||||
|
|
||||||
if clip_denoised:
|
if clip_denoised:
|
||||||
x_recon = torch.clamp(x_recon, -.5, .5)
|
x_recon = torch.clamp(x_recon, -0.5, 0.5)
|
||||||
|
|
||||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
|
model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(self.loss_type)
|
raise NotImplementedError(self.loss_type)
|
||||||
|
|
||||||
|
|
||||||
assert model_mean.shape == x_recon.shape == data.shape
|
assert model_mean.shape == x_recon.shape == data.shape
|
||||||
assert model_variance.shape == model_log_variance.shape == data.shape
|
assert model_variance.shape == model_log_variance.shape == data.shape
|
||||||
if return_pred_xstart:
|
if return_pred_xstart:
|
||||||
|
@ -178,18 +186,19 @@ class GaussianDiffusion:
|
||||||
def _predict_xstart_from_eps(self, x_t, t, eps):
|
def _predict_xstart_from_eps(self, x_t, t, eps):
|
||||||
assert x_t.shape == eps.shape
|
assert x_t.shape == eps.shape
|
||||||
return (
|
return (
|
||||||
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
|
self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
|
||||||
self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
- self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
|
||||||
)
|
)
|
||||||
|
|
||||||
''' samples '''
|
""" samples """
|
||||||
|
|
||||||
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
|
def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
|
||||||
"""
|
"""
|
||||||
Sample from the model
|
Sample from the model
|
||||||
"""
|
"""
|
||||||
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised,
|
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
|
||||||
return_pred_xstart=True)
|
denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
|
||||||
|
)
|
||||||
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
|
noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
|
||||||
assert noise.shape == data.shape
|
assert noise.shape == data.shape
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
|
@ -201,10 +210,17 @@ class GaussianDiffusion:
|
||||||
assert sample.shape == pred_xstart.shape
|
assert sample.shape == pred_xstart.shape
|
||||||
return (sample, pred_xstart) if return_pred_xstart else sample
|
return (sample, pred_xstart) if return_pred_xstart else sample
|
||||||
|
|
||||||
|
def p_sample_loop(
|
||||||
def p_sample_loop(self, denoise_fn, shape, device,
|
self,
|
||||||
noise_fn=torch.randn, constrain_fn=lambda x, t:x,
|
denoise_fn,
|
||||||
clip_denoised=True, max_timestep=None, keep_running=False):
|
shape,
|
||||||
|
device,
|
||||||
|
noise_fn=torch.randn,
|
||||||
|
constrain_fn=lambda x, t: x,
|
||||||
|
clip_denoised=True,
|
||||||
|
max_timestep=None,
|
||||||
|
keep_running=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Generate samples
|
Generate samples
|
||||||
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
|
keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
|
||||||
|
@ -220,28 +236,38 @@ class GaussianDiffusion:
|
||||||
for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
|
for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
|
||||||
img_t = constrain_fn(img_t, t)
|
img_t = constrain_fn(img_t, t)
|
||||||
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
|
||||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn,
|
img_t = self.p_sample(
|
||||||
clip_denoised=clip_denoised, return_pred_xstart=False).detach()
|
denoise_fn=denoise_fn,
|
||||||
|
data=img_t,
|
||||||
|
t=t_,
|
||||||
|
noise_fn=noise_fn,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
return_pred_xstart=False,
|
||||||
|
).detach()
|
||||||
|
|
||||||
assert img_t.shape == shape
|
assert img_t.shape == shape
|
||||||
return img_t
|
return img_t
|
||||||
|
|
||||||
def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x):
|
def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
|
||||||
|
|
||||||
assert t >= 1
|
assert t >= 1
|
||||||
|
|
||||||
t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1)
|
t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
|
||||||
encoding = self.q_sample(x0, t_vec)
|
encoding = self.q_sample(x0, t_vec)
|
||||||
|
|
||||||
img_t = encoding
|
img_t = encoding
|
||||||
|
|
||||||
for k in reversed(range(0,t)):
|
for k in reversed(range(0, t)):
|
||||||
img_t = constrain_fn(img_t, k)
|
img_t = constrain_fn(img_t, k)
|
||||||
t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
|
t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
|
||||||
img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn,
|
img_t = self.p_sample(
|
||||||
clip_denoised=False, return_pred_xstart=False, use_var=True).detach()
|
denoise_fn=denoise_fn,
|
||||||
|
data=img_t,
|
||||||
|
t=t_,
|
||||||
|
noise_fn=noise_fn,
|
||||||
|
clip_denoised=False,
|
||||||
|
return_pred_xstart=False,
|
||||||
|
use_var=True,
|
||||||
|
).detach()
|
||||||
|
|
||||||
return img_t
|
return img_t
|
||||||
|
|
||||||
|
@ -260,40 +286,50 @@ class PVCNN2(PVCNN2Base):
|
||||||
((128, 128, 64), (64, 2, 32)),
|
((128, 128, 64), (64, 2, 32)),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1,
|
def __init__(
|
||||||
voxel_resolution_multiplier=1):
|
self,
|
||||||
|
num_classes,
|
||||||
|
embed_dim,
|
||||||
|
use_att,
|
||||||
|
dropout,
|
||||||
|
extra_feature_channels=3,
|
||||||
|
width_multiplier=1,
|
||||||
|
voxel_resolution_multiplier=1,
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
|
num_classes=num_classes,
|
||||||
dropout=dropout, extra_feature_channels=extra_feature_channels,
|
embed_dim=embed_dim,
|
||||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
|
use_att=use_att,
|
||||||
|
dropout=dropout,
|
||||||
|
extra_feature_channels=extra_feature_channels,
|
||||||
|
width_multiplier=width_multiplier,
|
||||||
|
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str):
|
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
|
self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)
|
||||||
|
|
||||||
self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention,
|
self.model = PVCNN2(
|
||||||
dropout=args.dropout, extra_feature_channels=0)
|
num_classes=args.nc,
|
||||||
|
embed_dim=args.embed_dim,
|
||||||
|
use_att=args.attention,
|
||||||
|
dropout=args.dropout,
|
||||||
|
extra_feature_channels=0,
|
||||||
|
)
|
||||||
|
|
||||||
def prior_kl(self, x0):
|
def prior_kl(self, x0):
|
||||||
return self.diffusion._prior_bpd(x0)
|
return self.diffusion._prior_bpd(x0)
|
||||||
|
|
||||||
def all_kl(self, x0, clip_denoised=True):
|
def all_kl(self, x0, clip_denoised=True):
|
||||||
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)
|
||||||
|
|
||||||
return {
|
|
||||||
'total_bpd_b': total_bpd_b,
|
|
||||||
'terms_bpd': vals_bt,
|
|
||||||
'prior_bpd_b': prior_bpd_b,
|
|
||||||
'mse_bt':mse_bt
|
|
||||||
}
|
|
||||||
|
|
||||||
|
return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}
|
||||||
|
|
||||||
def _denoise(self, data, t):
|
def _denoise(self, data, t):
|
||||||
B, D,N= data.shape
|
B, D, N = data.shape
|
||||||
assert data.dtype == torch.float
|
assert data.dtype == torch.float
|
||||||
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
|
assert t.shape == torch.Size([B]) and t.dtype == torch.int64
|
||||||
|
|
||||||
|
@ -307,23 +343,34 @@ class Model(nn.Module):
|
||||||
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
|
t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)
|
||||||
|
|
||||||
if noises is not None:
|
if noises is not None:
|
||||||
noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)
|
noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)
|
||||||
|
|
||||||
losses = self.diffusion.p_losses(
|
losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
||||||
denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
|
|
||||||
assert losses.shape == t.shape == torch.Size([B])
|
assert losses.shape == t.shape == torch.Size([B])
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x,
|
def gen_samples(
|
||||||
clip_denoised=False, max_timestep=None,
|
self,
|
||||||
keep_running=False):
|
shape,
|
||||||
return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn,
|
device,
|
||||||
constrain_fn=constrain_fn,
|
noise_fn=torch.randn,
|
||||||
clip_denoised=clip_denoised, max_timestep=max_timestep,
|
constrain_fn=lambda x, t: x,
|
||||||
keep_running=keep_running)
|
clip_denoised=False,
|
||||||
|
max_timestep=None,
|
||||||
def reconstruct(self, x0, t, constrain_fn=lambda x, t:x):
|
keep_running=False,
|
||||||
|
):
|
||||||
|
return self.diffusion.p_sample_loop(
|
||||||
|
self._denoise,
|
||||||
|
shape=shape,
|
||||||
|
device=device,
|
||||||
|
noise_fn=noise_fn,
|
||||||
|
constrain_fn=constrain_fn,
|
||||||
|
clip_denoised=clip_denoised,
|
||||||
|
max_timestep=max_timestep,
|
||||||
|
keep_running=keep_running,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reconstruct(self, x0, t, constrain_fn=lambda x, t: x):
|
||||||
return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
|
return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
@ -337,20 +384,17 @@ class Model(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def get_betas(schedule_type, b_start, b_end, time_num):
|
def get_betas(schedule_type, b_start, b_end, time_num):
|
||||||
if schedule_type == 'linear':
|
if schedule_type == "linear":
|
||||||
betas = np.linspace(b_start, b_end, time_num)
|
betas = np.linspace(b_start, b_end, time_num)
|
||||||
elif schedule_type == 'warm0.1':
|
elif schedule_type == "warm0.1":
|
||||||
|
|
||||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||||
warmup_time = int(time_num * 0.1)
|
warmup_time = int(time_num * 0.1)
|
||||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||||
elif schedule_type == 'warm0.2':
|
elif schedule_type == "warm0.2":
|
||||||
|
|
||||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||||
warmup_time = int(time_num * 0.2)
|
warmup_time = int(time_num * 0.2)
|
||||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||||
elif schedule_type == 'warm0.5':
|
elif schedule_type == "warm0.5":
|
||||||
|
|
||||||
betas = b_end * np.ones(time_num, dtype=np.float64)
|
betas = b_end * np.ones(time_num, dtype=np.float64)
|
||||||
warmup_time = int(time_num * 0.5)
|
warmup_time = int(time_num * 0.5)
|
||||||
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
|
||||||
|
@ -358,111 +402,109 @@ def get_betas(schedule_type, b_start, b_end, time_num):
|
||||||
raise NotImplementedError(schedule_type)
|
raise NotImplementedError(schedule_type)
|
||||||
return betas
|
return betas
|
||||||
|
|
||||||
|
|
||||||
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
|
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
|
||||||
'''
|
"""
|
||||||
|
|
||||||
:param target_shape_constraint: target voxels
|
:param target_shape_constraint: target voxels
|
||||||
:return: constrained x
|
:return: constrained x
|
||||||
'''
|
"""
|
||||||
# eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
|
# eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2))
|
||||||
eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 ))
|
eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2))
|
||||||
def constrain_fn(x, t):
|
|
||||||
eps_ = eps_all[t] if (t<1000) else 0
|
|
||||||
for _ in range(num_steps):
|
|
||||||
x = x - eps_ * ((x - ground_truth) * mask)
|
|
||||||
|
|
||||||
|
def constrain_fn(x, t):
|
||||||
|
eps_ = eps_all[t] if (t < 1000) else 0
|
||||||
|
for _ in range(num_steps):
|
||||||
|
x = x - eps_ * ((x - ground_truth) * mask)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
return constrain_fn
|
return constrain_fn
|
||||||
|
|
||||||
|
|
||||||
#############################################################################
|
#############################################################################
|
||||||
|
|
||||||
def get_dataset(dataroot, npoints,category,use_mask=False):
|
|
||||||
tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
|
def get_dataset(dataroot, npoints, category, use_mask=False):
|
||||||
categories=[category], split='train',
|
tr_dataset = ShapeNet15kPointClouds(
|
||||||
|
root_dir=dataroot,
|
||||||
|
categories=[category],
|
||||||
|
split="train",
|
||||||
tr_sample_size=npoints,
|
tr_sample_size=npoints,
|
||||||
te_sample_size=npoints,
|
te_sample_size=npoints,
|
||||||
scale=1.,
|
scale=1.0,
|
||||||
normalize_per_shape=False,
|
normalize_per_shape=False,
|
||||||
normalize_std_per_axis=False,
|
normalize_std_per_axis=False,
|
||||||
random_subsample=True, use_mask = use_mask)
|
random_subsample=True,
|
||||||
te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
|
use_mask=use_mask,
|
||||||
categories=[category], split='val',
|
)
|
||||||
|
te_dataset = ShapeNet15kPointClouds(
|
||||||
|
root_dir=dataroot,
|
||||||
|
categories=[category],
|
||||||
|
split="val",
|
||||||
tr_sample_size=npoints,
|
tr_sample_size=npoints,
|
||||||
te_sample_size=npoints,
|
te_sample_size=npoints,
|
||||||
scale=1.,
|
scale=1.0,
|
||||||
normalize_per_shape=False,
|
normalize_per_shape=False,
|
||||||
normalize_std_per_axis=False,
|
normalize_std_per_axis=False,
|
||||||
all_points_mean=tr_dataset.all_points_mean,
|
all_points_mean=tr_dataset.all_points_mean,
|
||||||
all_points_std=tr_dataset.all_points_std,
|
all_points_std=tr_dataset.all_points_std,
|
||||||
use_mask=use_mask
|
use_mask=use_mask,
|
||||||
)
|
)
|
||||||
return tr_dataset, te_dataset
|
return tr_dataset, te_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_gen(opt, ref_pcs, logger):
|
def evaluate_gen(opt, ref_pcs, logger):
|
||||||
|
|
||||||
if ref_pcs is None:
|
if ref_pcs is None:
|
||||||
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False)
|
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False)
|
||||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
|
test_dataloader = torch.utils.data.DataLoader(
|
||||||
shuffle=False, num_workers=int(opt.workers), drop_last=False)
|
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
|
||||||
|
)
|
||||||
ref = []
|
ref = []
|
||||||
for data in tqdm(test_dataloader, total=len(test_dataloader), desc='Generating Samples'):
|
for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
|
||||||
x = data['test_points']
|
x = data["test_points"]
|
||||||
m, s = data['mean'].float(), data['std'].float()
|
m, s = data["mean"].float(), data["std"].float()
|
||||||
|
|
||||||
ref.append(x*s + m)
|
ref.append(x * s + m)
|
||||||
|
|
||||||
ref_pcs = torch.cat(ref, dim=0).contiguous()
|
ref_pcs = torch.cat(ref, dim=0).contiguous()
|
||||||
|
|
||||||
logger.info("Loading sample path: %s"
|
logger.info("Loading sample path: %s" % (opt.eval_path))
|
||||||
% (opt.eval_path))
|
|
||||||
sample_pcs = torch.load(opt.eval_path).contiguous()
|
sample_pcs = torch.load(opt.eval_path).contiguous()
|
||||||
|
|
||||||
logger.info("Generation sample size:%s reference size: %s"
|
logger.info("Generation sample size:%s reference size: %s" % (sample_pcs.size(), ref_pcs.size()))
|
||||||
% (sample_pcs.size(), ref_pcs.size()))
|
|
||||||
|
|
||||||
|
|
||||||
# Compute metrics
|
# Compute metrics
|
||||||
results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
|
results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
|
||||||
results = {k: (v.cpu().detach().item()
|
results = {k: (v.cpu().detach().item() if not isinstance(v, float) else v) for k, v in results.items()}
|
||||||
if not isinstance(v, float) else v) for k, v in results.items()}
|
|
||||||
|
|
||||||
pprint(results)
|
pprint(results)
|
||||||
logger.info(results)
|
logger.info(results)
|
||||||
|
|
||||||
jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
|
jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
|
||||||
pprint('JSD: {}'.format(jsd))
|
pprint("JSD: {}".format(jsd))
|
||||||
logger.info('JSD: {}'.format(jsd))
|
logger.info("JSD: {}".format(jsd))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate(model, opt):
|
def generate(model, opt):
|
||||||
|
|
||||||
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category)
|
_, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category)
|
||||||
|
|
||||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size,
|
test_dataloader = torch.utils.data.DataLoader(
|
||||||
shuffle=False, num_workers=int(opt.workers), drop_last=False)
|
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
samples = []
|
samples = []
|
||||||
ref = []
|
ref = []
|
||||||
|
|
||||||
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'):
|
for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Generating Samples"):
|
||||||
|
x = data["test_points"].transpose(1, 2)
|
||||||
x = data['test_points'].transpose(1,2)
|
m, s = data["mean"].float(), data["std"].float()
|
||||||
m, s = data['mean'].float(), data['std'].float()
|
|
||||||
|
|
||||||
gen = model.gen_samples(x.shape,
|
|
||||||
'cuda', clip_denoised=False).detach().cpu()
|
|
||||||
|
|
||||||
gen = gen.transpose(1,2).contiguous()
|
|
||||||
x = x.transpose(1,2).contiguous()
|
|
||||||
|
|
||||||
|
gen = model.gen_samples(x.shape, "cuda", clip_denoised=False).detach().cpu()
|
||||||
|
|
||||||
|
gen = gen.transpose(1, 2).contiguous()
|
||||||
|
x = x.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
gen = gen * s + m
|
gen = gen * s + m
|
||||||
x = x * s + m
|
x = x * s + m
|
||||||
|
@ -482,20 +524,20 @@ def generate(model, opt):
|
||||||
# 1,
|
# 1,
|
||||||
# 0.5,
|
# 0.5,
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
# visualize using matplotlib
|
# visualize using matplotlib
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import matplotlib.cm as cm
|
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('TkAgg')
|
import matplotlib.cm as cm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
matplotlib.use("TkAgg")
|
||||||
for idx, pc in enumerate(gen[:64]):
|
for idx, pc in enumerate(gen[:64]):
|
||||||
print(f"Visualizing point cloud {idx}...")
|
print(f"Visualizing point cloud {idx}...")
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
ax = fig.add_subplot(111, projection="3d")
|
||||||
ax.scatter(pc[:,0], pc[:,1], pc[:,2], c=pc[:,2], cmap=cm.jet)
|
ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
|
||||||
ax.set_aspect('equal')
|
ax.set_aspect("equal")
|
||||||
ax.axis('off')
|
ax.axis("off")
|
||||||
# ax.set_xlim(-1, 1)
|
# ax.set_xlim(-1, 1)
|
||||||
# ax.set_ylim(-1, 1)
|
# ax.set_ylim(-1, 1)
|
||||||
# ax.set_zlim(-1, 1)
|
# ax.set_zlim(-1, 1)
|
||||||
|
@ -507,17 +549,14 @@ def generate(model, opt):
|
||||||
|
|
||||||
torch.save(samples, opt.eval_path)
|
torch.save(samples, opt.eval_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return ref
|
return ref
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
|
if opt.category == "airplane":
|
||||||
if opt.category == 'airplane':
|
|
||||||
opt.beta_start = 1e-5
|
opt.beta_start = 1e-5
|
||||||
opt.beta_end = 0.008
|
opt.beta_end = 0.008
|
||||||
opt.schedule_type = 'warm0.1'
|
opt.schedule_type = "warm0.1"
|
||||||
|
|
||||||
exp_id = os.path.splitext(os.path.basename(__file__))[0]
|
exp_id = os.path.splitext(os.path.basename(__file__))[0]
|
||||||
dir_id = os.path.dirname(__file__)
|
dir_id = os.path.dirname(__file__)
|
||||||
|
@ -525,7 +564,7 @@ def main(opt):
|
||||||
copy_source(__file__, output_dir)
|
copy_source(__file__, output_dir)
|
||||||
logger = setup_logging(output_dir)
|
logger = setup_logging(output_dir)
|
||||||
|
|
||||||
outf_syn, = setup_output_subdirs(output_dir, 'syn')
|
(outf_syn,) = setup_output_subdirs(output_dir, "syn")
|
||||||
|
|
||||||
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
|
||||||
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
|
||||||
|
@ -542,64 +581,59 @@ def main(opt):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
||||||
logger.info("Resume Path:%s" % opt.model)
|
logger.info("Resume Path:%s" % opt.model)
|
||||||
|
|
||||||
resumed_param = torch.load(opt.model)
|
resumed_param = torch.load(opt.model)
|
||||||
model.load_state_dict(resumed_param['model_state'])
|
model.load_state_dict(resumed_param["model_state"])
|
||||||
|
|
||||||
|
|
||||||
ref = None
|
ref = None
|
||||||
if opt.generate:
|
if opt.generate:
|
||||||
opt.eval_path = os.path.join(outf_syn, 'samples.pth')
|
opt.eval_path = os.path.join(outf_syn, "samples.pth")
|
||||||
Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
ref=generate(model, opt)
|
ref = generate(model, opt)
|
||||||
|
|
||||||
if opt.eval_gen:
|
if opt.eval_gen:
|
||||||
# Evaluate generation
|
# Evaluate generation
|
||||||
evaluate_gen(opt, ref, logger)
|
evaluate_gen(opt, ref, logger)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
|
parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
|
||||||
parser.add_argument('--category', default='chair')
|
parser.add_argument("--category", default="chair")
|
||||||
|
|
||||||
parser.add_argument('--batch_size', type=int, default=50, help='input batch size')
|
parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
|
||||||
parser.add_argument('--workers', type=int, default=16, help='workers')
|
parser.add_argument("--workers", type=int, default=16, help="workers")
|
||||||
parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
|
parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")
|
||||||
|
|
||||||
parser.add_argument('--generate',default=True)
|
parser.add_argument("--generate", default=True)
|
||||||
parser.add_argument('--eval_gen', default=True)
|
parser.add_argument("--eval_gen", default=True)
|
||||||
|
|
||||||
parser.add_argument('--nc', default=3)
|
parser.add_argument("--nc", default=3)
|
||||||
parser.add_argument('--npoints', default=2048)
|
parser.add_argument("--npoints", default=2048)
|
||||||
'''model'''
|
"""model"""
|
||||||
parser.add_argument('--beta_start', default=0.0001)
|
parser.add_argument("--beta_start", default=0.0001)
|
||||||
parser.add_argument('--beta_end', default=0.02)
|
parser.add_argument("--beta_end", default=0.02)
|
||||||
parser.add_argument('--schedule_type', default='linear')
|
parser.add_argument("--schedule_type", default="linear")
|
||||||
parser.add_argument('--time_num', default=1000)
|
parser.add_argument("--time_num", default=1000)
|
||||||
|
|
||||||
#params
|
# params
|
||||||
parser.add_argument('--attention', default=True)
|
parser.add_argument("--attention", default=True)
|
||||||
parser.add_argument('--dropout', default=0.1)
|
parser.add_argument("--dropout", default=0.1)
|
||||||
parser.add_argument('--embed_dim', type=int, default=64)
|
parser.add_argument("--embed_dim", type=int, default=64)
|
||||||
parser.add_argument('--loss_type', default='mse')
|
parser.add_argument("--loss_type", default="mse")
|
||||||
parser.add_argument('--model_mean_type', default='eps')
|
parser.add_argument("--model_mean_type", default="eps")
|
||||||
parser.add_argument('--model_var_type', default='fixedsmall')
|
parser.add_argument("--model_var_type", default="fixedsmall")
|
||||||
|
|
||||||
|
parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")
|
||||||
|
|
||||||
parser.add_argument('--model', default='',required=True, help="path to model (to continue training)")
|
"""eval"""
|
||||||
|
|
||||||
'''eval'''
|
parser.add_argument("--eval_path", default="")
|
||||||
|
|
||||||
parser.add_argument('--eval_path',
|
parser.add_argument("--manualSeed", default=42, type=int, help="random seed")
|
||||||
default='')
|
|
||||||
|
|
||||||
parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
|
parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")
|
||||||
|
|
||||||
parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)')
|
|
||||||
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
@ -609,7 +643,9 @@ def parse_args():
|
||||||
opt.cuda = False
|
opt.cuda = False
|
||||||
|
|
||||||
return opt
|
return opt
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
opt = parse_args()
|
opt = parse_args()
|
||||||
set_seed(opt)
|
set_seed(opt)
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -1,35 +1,33 @@
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
import datetime
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def set_global_gpu_env(opt):
|
def set_global_gpu_env(opt):
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)
|
|
||||||
|
|
||||||
|
|
||||||
torch.cuda.set_device(opt.gpu)
|
torch.cuda.set_device(opt.gpu)
|
||||||
|
|
||||||
|
|
||||||
def copy_source(file, output_dir):
|
def copy_source(file, output_dir):
|
||||||
copyfile(file, os.path.join(output_dir, os.path.basename(file)))
|
copyfile(file, os.path.join(output_dir, os.path.basename(file)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(output_dir):
|
def setup_logging(output_dir):
|
||||||
log_format = logging.Formatter("%(asctime)s : %(message)s")
|
log_format = logging.Formatter("%(asctime)s : %(message)s")
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.handlers = []
|
logger.handlers = []
|
||||||
output_file = os.path.join(output_dir, 'output.log')
|
output_file = os.path.join(output_dir, "output.log")
|
||||||
file_handler = logging.FileHandler(output_file)
|
file_handler = logging.FileHandler(output_file)
|
||||||
file_handler.setFormatter(log_format)
|
file_handler.setFormatter(log_format)
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
@ -44,16 +42,14 @@ def setup_logging(output_dir):
|
||||||
|
|
||||||
|
|
||||||
def get_output_dir(prefix, exp_id):
|
def get_output_dir(prefix, exp_id):
|
||||||
t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
output_dir = os.path.join(prefix, 'output/' + exp_id, t)
|
output_dir = os.path.join(prefix, "output/" + exp_id, t)
|
||||||
if not os.path.exists(output_dir):
|
if not os.path.exists(output_dir):
|
||||||
os.makedirs(output_dir)
|
os.makedirs(output_dir)
|
||||||
return output_dir
|
return output_dir
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(opt):
|
def set_seed(opt):
|
||||||
|
|
||||||
if opt.manualSeed is None:
|
if opt.manualSeed is None:
|
||||||
opt.manualSeed = random.randint(1, 10000)
|
opt.manualSeed = random.randint(1, 10000)
|
||||||
print("Random Seed: ", opt.manualSeed)
|
print("Random Seed: ", opt.manualSeed)
|
||||||
|
@ -65,8 +61,8 @@ def set_seed(opt):
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
def setup_output_subdirs(output_dir, *subfolders):
|
|
||||||
|
|
||||||
|
def setup_output_subdirs(output_dir, *subfolders):
|
||||||
output_subdirs = output_dir
|
output_subdirs = output_dir
|
||||||
try:
|
try:
|
||||||
os.makedirs(output_subdirs)
|
os.makedirs(output_subdirs)
|
||||||
|
@ -82,4 +78,4 @@ def setup_output_subdirs(output_dir, *subfolders):
|
||||||
pass
|
pass
|
||||||
subfolder_list.append(curr_subf)
|
subfolder_list.append(curr_subf)
|
||||||
|
|
||||||
return subfolder_list
|
return subfolder_list
|
||||||
|
|
|
@ -1,20 +1,22 @@
|
||||||
import numpy as np
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from scipy.stats import entropy
|
from scipy.stats import entropy
|
||||||
|
|
||||||
|
|
||||||
def iterate_in_chunks(l, n):
|
def iterate_in_chunks(l, n):
|
||||||
'''Yield successive 'n'-sized chunks from iterable 'l'.
|
"""Yield successive 'n'-sized chunks from iterable 'l'.
|
||||||
Note: last chunk will be smaller than l if n doesn't divide l perfectly.
|
Note: last chunk will be smaller than l if n doesn't divide l perfectly.
|
||||||
'''
|
"""
|
||||||
for i in range(0, len(l), n):
|
for i in range(0, len(l), n):
|
||||||
yield l[i:i + n]
|
yield l[i : i + n]
|
||||||
|
|
||||||
|
|
||||||
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
||||||
'''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
|
"""Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
|
||||||
that is placed in the unit-cube.
|
that is placed in the unit-cube.
|
||||||
If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
|
If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
|
||||||
'''
|
"""
|
||||||
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
|
grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
|
||||||
spacing = 1.0 / float(resolution - 1)
|
spacing = 1.0 / float(resolution - 1)
|
||||||
for i in range(resolution):
|
for i in range(resolution):
|
||||||
|
@ -30,9 +32,11 @@ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
|
||||||
|
|
||||||
return grid, spacing
|
return grid, spacing
|
||||||
|
|
||||||
def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False,
|
|
||||||
use_EMD=False):
|
def minimum_mathing_distance(
|
||||||
'''Computes the MMD between two sets of point-clouds.
|
sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False
|
||||||
|
):
|
||||||
|
"""Computes the MMD between two sets of point-clouds.
|
||||||
Args:
|
Args:
|
||||||
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and
|
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and
|
||||||
compared to a set of "reference" point-clouds.
|
compared to a set of "reference" point-clouds.
|
||||||
|
@ -49,17 +53,17 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
|
||||||
use_EMD (boolean: If true, the matchings are based on the EMD.
|
use_EMD (boolean: If true, the matchings are based on the EMD.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing the MMD and all the matched distances of which the MMD is their mean.
|
A tuple containing the MMD and all the matched distances of which the MMD is their mean.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
n_ref, n_pc_points, pc_dim = ref_pcs.shape
|
n_ref, n_pc_points, pc_dim = ref_pcs.shape
|
||||||
_, n_pc_points_s, pc_dim_s = sample_pcs.shape
|
_, n_pc_points_s, pc_dim_s = sample_pcs.shape
|
||||||
|
|
||||||
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
|
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
|
||||||
raise ValueError('Incompatible size of point-clouds.')
|
raise ValueError("Incompatible size of point-clouds.")
|
||||||
|
|
||||||
ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(n_pc_points, normalize=normalize,
|
ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(
|
||||||
sess=sess, use_sqrt=use_sqrt,
|
n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
|
||||||
use_EMD=use_EMD)
|
)
|
||||||
matched_dists = []
|
matched_dists = []
|
||||||
for i in range(n_ref):
|
for i in range(n_ref):
|
||||||
best_in_all_batches = []
|
best_in_all_batches = []
|
||||||
|
@ -75,9 +79,18 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
|
||||||
return mmd, matched_dists
|
return mmd, matched_dists
|
||||||
|
|
||||||
|
|
||||||
def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False,
|
def coverage(
|
||||||
ret_dist=False):
|
sample_pcs,
|
||||||
'''Computes the Coverage between two sets of point-clouds.
|
ref_pcs,
|
||||||
|
batch_size,
|
||||||
|
normalize=True,
|
||||||
|
sess=None,
|
||||||
|
verbose=False,
|
||||||
|
use_sqrt=False,
|
||||||
|
use_EMD=False,
|
||||||
|
ret_dist=False,
|
||||||
|
):
|
||||||
|
"""Computes the Coverage between two sets of point-clouds.
|
||||||
Args:
|
Args:
|
||||||
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
|
sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
|
||||||
and compared to a set of "reference" point-clouds.
|
and compared to a set of "reference" point-clouds.
|
||||||
|
@ -97,18 +110,16 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
|
||||||
Returns: the coverage score (int),
|
Returns: the coverage score (int),
|
||||||
the indices of the ref_pcs that are matched with each sample_pc
|
the indices of the ref_pcs that are matched with each sample_pc
|
||||||
and optionally the matched distances of the samples_pcs.
|
and optionally the matched distances of the samples_pcs.
|
||||||
'''
|
"""
|
||||||
n_ref, n_pc_points, pc_dim = ref_pcs.shape
|
n_ref, n_pc_points, pc_dim = ref_pcs.shape
|
||||||
n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape
|
n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape
|
||||||
|
|
||||||
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
|
if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
|
||||||
raise ValueError('Incompatible Point-Clouds.')
|
raise ValueError("Incompatible Point-Clouds.")
|
||||||
|
|
||||||
ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(n_pc_points,
|
ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(
|
||||||
normalize=normalize,
|
n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
|
||||||
sess=sess,
|
)
|
||||||
use_sqrt=use_sqrt,
|
|
||||||
use_EMD=use_EMD)
|
|
||||||
matched_gt = []
|
matched_gt = []
|
||||||
matched_dist = []
|
matched_dist = []
|
||||||
for i in xrange(n_sam):
|
for i in xrange(n_sam):
|
||||||
|
@ -140,12 +151,12 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
|
||||||
|
|
||||||
|
|
||||||
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
|
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
|
||||||
'''Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
|
"""Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
|
||||||
Args:
|
Args:
|
||||||
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
|
sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
|
||||||
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
|
ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
|
||||||
resolution: (int) grid-resolution. Affects granularity of measurements.
|
resolution: (int) grid-resolution. Affects granularity of measurements.
|
||||||
'''
|
"""
|
||||||
in_unit_sphere = True
|
in_unit_sphere = True
|
||||||
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
|
sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
|
||||||
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
|
ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
|
||||||
|
@ -153,19 +164,19 @@ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
|
||||||
|
|
||||||
|
|
||||||
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
||||||
'''Given a collection of point-clouds, estimate the entropy of the random variables
|
"""Given a collection of point-clouds, estimate the entropy of the random variables
|
||||||
corresponding to occupancy-grid activation patterns.
|
corresponding to occupancy-grid activation patterns.
|
||||||
Inputs:
|
Inputs:
|
||||||
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
|
pclouds: (numpy array) #point-clouds x points per point-cloud x 3
|
||||||
grid_resolution (int) size of occupancy grid that will be used.
|
grid_resolution (int) size of occupancy grid that will be used.
|
||||||
'''
|
"""
|
||||||
epsilon = 10e-4
|
epsilon = 10e-4
|
||||||
bound = 0.5 + epsilon
|
bound = 0.5 + epsilon
|
||||||
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
|
||||||
warnings.warn('Point-clouds are not in unit cube.')
|
warnings.warn("Point-clouds are not in unit cube.")
|
||||||
|
|
||||||
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
|
if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
|
||||||
warnings.warn('Point-clouds are not in unit sphere.')
|
warnings.warn("Point-clouds are not in unit sphere.")
|
||||||
|
|
||||||
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
|
||||||
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
grid_coordinates = grid_coordinates.reshape(-1, 3)
|
||||||
|
@ -192,13 +203,14 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
|
||||||
|
|
||||||
return acc_entropy / len(grid_counters), grid_counters
|
return acc_entropy / len(grid_counters), grid_counters
|
||||||
|
|
||||||
|
|
||||||
def jensen_shannon_divergence(P, Q):
|
def jensen_shannon_divergence(P, Q):
|
||||||
if np.any(P < 0) or np.any(Q < 0):
|
if np.any(P < 0) or np.any(Q < 0):
|
||||||
raise ValueError('Negative values.')
|
raise ValueError("Negative values.")
|
||||||
if len(P) != len(Q):
|
if len(P) != len(Q):
|
||||||
raise ValueError('Non equal size.')
|
raise ValueError("Non equal size.")
|
||||||
|
|
||||||
P_ = P / np.sum(P) # Ensure probabilities.
|
P_ = P / np.sum(P) # Ensure probabilities.
|
||||||
Q_ = Q / np.sum(Q)
|
Q_ = Q / np.sum(Q)
|
||||||
|
|
||||||
e1 = entropy(P_, base=2)
|
e1 = entropy(P_, base=2)
|
||||||
|
@ -209,13 +221,14 @@ def jensen_shannon_divergence(P, Q):
|
||||||
res2 = _jsdiv(P_, Q_)
|
res2 = _jsdiv(P_, Q_)
|
||||||
|
|
||||||
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
if not np.allclose(res, res2, atol=10e-5, rtol=0):
|
||||||
warnings.warn('Numerical values of two JSD methods don\'t agree.')
|
warnings.warn("Numerical values of two JSD methods don't agree.")
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def _jsdiv(P, Q):
|
def _jsdiv(P, Q):
|
||||||
'''another way of computing JSD'''
|
"""another way of computing JSD"""
|
||||||
|
|
||||||
def _kldiv(A, B):
|
def _kldiv(A, B):
|
||||||
a = A.copy()
|
a = A.copy()
|
||||||
b = B.copy()
|
b = B.copy()
|
||||||
|
@ -229,4 +242,4 @@ def _jsdiv(P, Q):
|
||||||
|
|
||||||
M = 0.5 * (P_ + Q_)
|
M = 0.5 * (P_ + Q_)
|
||||||
|
|
||||||
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
|
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
|
||||||
|
|
|
@ -1,32 +1,33 @@
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('agg')
|
|
||||||
import matplotlib.pyplot as plt
|
matplotlib.use("agg")
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
|
||||||
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import trimesh
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
'''
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import trimesh
|
||||||
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
|
||||||
|
"""
|
||||||
Custom visualization
|
Custom visualization
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def export_to_pc_batch(dir, pcs, colors=None):
|
def export_to_pc_batch(dir, pcs, colors=None):
|
||||||
|
|
||||||
Path(dir).mkdir(parents=True, exist_ok=True)
|
Path(dir).mkdir(parents=True, exist_ok=True)
|
||||||
for i, xyz in enumerate(pcs):
|
for i, xyz in enumerate(pcs):
|
||||||
if colors is None:
|
if colors is None:
|
||||||
color = None
|
color = None
|
||||||
else:
|
else:
|
||||||
color = colors[i]
|
color = colors[i]
|
||||||
pcwrite(os.path.join(dir, 'sample_'+str(i)+'.ply'), xyz, color)
|
pcwrite(os.path.join(dir, "sample_" + str(i) + ".ply"), xyz, color)
|
||||||
|
|
||||||
|
|
||||||
def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)):
|
def export_to_obj(dir, meshes, transform=lambda v, f: (v, f)):
|
||||||
'''
|
"""
|
||||||
transform: f(vertices, faces) --> transformed (vertices, faces)
|
transform: f(vertices, faces) --> transformed (vertices, faces)
|
||||||
'''
|
"""
|
||||||
Path(dir).mkdir(parents=True, exist_ok=True)
|
Path(dir).mkdir(parents=True, exist_ok=True)
|
||||||
for i, data in enumerate(meshes):
|
for i, data in enumerate(meshes):
|
||||||
v, f = transform(data[0], data[1])
|
v, f = transform(data[0], data[1])
|
||||||
|
@ -36,14 +37,15 @@ def export_to_obj(dir, meshes, transform=lambda v,f:(v,f)):
|
||||||
v_color = None
|
v_color = None
|
||||||
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
||||||
out = trimesh.exchange.obj.export_obj(mesh)
|
out = trimesh.exchange.obj.export_obj(mesh)
|
||||||
with open(os.path.join(dir, 'sample_'+str(i)+'.obj'), 'w') as f:
|
with open(os.path.join(dir, "sample_" + str(i) + ".obj"), "w") as f:
|
||||||
f.write(out)
|
f.write(out)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
def export_to_obj_single(path, data, transform=lambda v,f:(v,f)):
|
|
||||||
'''
|
def export_to_obj_single(path, data, transform=lambda v, f: (v, f)):
|
||||||
|
"""
|
||||||
transform: f(vertices, faces) --> transformed (vertices, faces)
|
transform: f(vertices, faces) --> transformed (vertices, faces)
|
||||||
'''
|
"""
|
||||||
v, f = transform(data[0], data[1])
|
v, f = transform(data[0], data[1])
|
||||||
if len(data) > 2:
|
if len(data) > 2:
|
||||||
v_color = data[2]
|
v_color = data[2]
|
||||||
|
@ -51,15 +53,15 @@ def export_to_obj_single(path, data, transform=lambda v,f:(v,f)):
|
||||||
v_color = None
|
v_color = None
|
||||||
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
|
||||||
out = trimesh.exchange.obj.export_obj(mesh)
|
out = trimesh.exchange.obj.export_obj(mesh)
|
||||||
with open(path, 'w') as f:
|
with open(path, "w") as f:
|
||||||
f.write(out)
|
f.write(out)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
def meshwrite(filename, verts, faces, norms, colors):
|
def meshwrite(filename, verts, faces, norms, colors):
|
||||||
"""Save a 3D mesh to a polygon .ply file.
|
"""Save a 3D mesh to a polygon .ply file."""
|
||||||
"""
|
|
||||||
# Write header
|
# Write header
|
||||||
ply_file = open(filename, 'w')
|
ply_file = open(filename, "w")
|
||||||
ply_file.write("ply\n")
|
ply_file.write("ply\n")
|
||||||
ply_file.write("format ascii 1.0\n")
|
ply_file.write("format ascii 1.0\n")
|
||||||
ply_file.write("element vertex %d\n" % (verts.shape[0]))
|
ply_file.write("element vertex %d\n" % (verts.shape[0]))
|
||||||
|
@ -78,11 +80,20 @@ def meshwrite(filename, verts, faces, norms, colors):
|
||||||
|
|
||||||
# Write vertex list
|
# Write vertex list
|
||||||
for i in range(verts.shape[0]):
|
for i in range(verts.shape[0]):
|
||||||
ply_file.write("%f %f %f %f %f %f %d %d %d\n" % (
|
ply_file.write(
|
||||||
verts[i, 0], verts[i, 1], verts[i, 2],
|
"%f %f %f %f %f %f %d %d %d\n"
|
||||||
norms[i, 0], norms[i, 1], norms[i, 2],
|
% (
|
||||||
colors[i, 0], colors[i, 1], colors[i, 2],
|
verts[i, 0],
|
||||||
))
|
verts[i, 1],
|
||||||
|
verts[i, 2],
|
||||||
|
norms[i, 0],
|
||||||
|
norms[i, 1],
|
||||||
|
norms[i, 2],
|
||||||
|
colors[i, 0],
|
||||||
|
colors[i, 1],
|
||||||
|
colors[i, 2],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Write face list
|
# Write face list
|
||||||
for i in range(faces.shape[0]):
|
for i in range(faces.shape[0]):
|
||||||
|
@ -92,14 +103,13 @@ def meshwrite(filename, verts, faces, norms, colors):
|
||||||
|
|
||||||
|
|
||||||
def pcwrite(filename, xyz, rgb=None):
|
def pcwrite(filename, xyz, rgb=None):
|
||||||
"""Save a point cloud to a polygon .ply file.
|
"""Save a point cloud to a polygon .ply file."""
|
||||||
"""
|
|
||||||
if rgb is None:
|
if rgb is None:
|
||||||
rgb = np.ones_like(xyz) * 128
|
rgb = np.ones_like(xyz) * 128
|
||||||
rgb = rgb.astype(np.uint8)
|
rgb = rgb.astype(np.uint8)
|
||||||
|
|
||||||
# Write header
|
# Write header
|
||||||
ply_file = open(filename, 'w')
|
ply_file = open(filename, "w")
|
||||||
ply_file.write("ply\n")
|
ply_file.write("ply\n")
|
||||||
ply_file.write("format ascii 1.0\n")
|
ply_file.write("format ascii 1.0\n")
|
||||||
ply_file.write("element vertex %d\n" % (xyz.shape[0]))
|
ply_file.write("element vertex %d\n" % (xyz.shape[0]))
|
||||||
|
@ -113,60 +123,67 @@ def pcwrite(filename, xyz, rgb=None):
|
||||||
|
|
||||||
# Write vertex list
|
# Write vertex list
|
||||||
for i in range(xyz.shape[0]):
|
for i in range(xyz.shape[0]):
|
||||||
ply_file.write("%f %f %f %d %d %d\n" % (
|
ply_file.write(
|
||||||
xyz[i, 0], xyz[i, 1], xyz[i, 2],
|
"%f %f %f %d %d %d\n"
|
||||||
rgb[i, 0], rgb[i, 1], rgb[i, 2],
|
% (
|
||||||
))
|
xyz[i, 0],
|
||||||
|
xyz[i, 1],
|
||||||
|
xyz[i, 2],
|
||||||
|
rgb[i, 0],
|
||||||
|
rgb[i, 1],
|
||||||
|
rgb[i, 2],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
'''
|
|
||||||
|
"""
|
||||||
Matplotlib Visualization
|
Matplotlib Visualization
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
|
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
|
||||||
r''' Visualizes voxel data.
|
r"""Visualizes voxel data.
|
||||||
show only first num_shown
|
show only first num_shown
|
||||||
'''
|
"""
|
||||||
batch_size =voxels.shape[0]
|
batch_size = voxels.shape[0]
|
||||||
voxels = voxels.squeeze(1) > threshold
|
voxels = voxels.squeeze(1) > threshold
|
||||||
|
|
||||||
num_shown = min(num_shown, batch_size)
|
num_shown = min(num_shown, batch_size)
|
||||||
|
|
||||||
n = int(np.sqrt(num_shown))
|
n = int(np.sqrt(num_shown))
|
||||||
fig = plt.figure(figsize=(20,20))
|
fig = plt.figure(figsize=(20, 20))
|
||||||
|
|
||||||
for idx, pc in enumerate(voxels[:num_shown]):
|
for idx, pc in enumerate(voxels[:num_shown]):
|
||||||
if idx >= n*n:
|
if idx >= n * n:
|
||||||
break
|
break
|
||||||
pc = voxels[idx]
|
pc = voxels[idx]
|
||||||
ax = fig.add_subplot(n, n, idx + 1, projection='3d')
|
ax = fig.add_subplot(n, n, idx + 1, projection="3d")
|
||||||
ax.voxels(pc, edgecolor='k', facecolors='green', linewidth=0.1, alpha=0.5)
|
ax.voxels(pc, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5)
|
||||||
ax.view_init()
|
ax.view_init()
|
||||||
ax.axis('off')
|
ax.axis("off")
|
||||||
plt.savefig(out_file, bbox_inches='tight')
|
plt.savefig(out_file, bbox_inches="tight")
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
def visualize_pointcloud(points, normals=None,
|
|
||||||
out_file=None, show=False, elev=30, azim=225):
|
def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
|
||||||
r''' Visualizes point cloud data.
|
r"""Visualizes point cloud data.
|
||||||
Args:
|
Args:
|
||||||
points (tensor): point data
|
points (tensor): point data
|
||||||
normals (tensor): normal data (if existing)
|
normals (tensor): normal data (if existing)
|
||||||
out_file (string): output file
|
out_file (string): output file
|
||||||
show (bool): whether the plot should be shown
|
show (bool): whether the plot should be shown
|
||||||
'''
|
"""
|
||||||
# Create plot
|
# Create plot
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
ax = fig.gca(projection=Axes3D.name)
|
ax = fig.gca(projection=Axes3D.name)
|
||||||
ax.scatter(points[:, 2], points[:, 0], points[:, 1])
|
ax.scatter(points[:, 2], points[:, 0], points[:, 1])
|
||||||
if normals is not None:
|
if normals is not None:
|
||||||
ax.quiver(
|
ax.quiver(
|
||||||
points[:, 2], points[:, 0], points[:, 1],
|
points[:, 2], points[:, 0], points[:, 1], normals[:, 2], normals[:, 0], normals[:, 1], length=0.1, color="k"
|
||||||
normals[:, 2], normals[:, 0], normals[:, 1],
|
|
||||||
length=0.1, color='k'
|
|
||||||
)
|
)
|
||||||
ax.set_xlabel('Z')
|
ax.set_xlabel("Z")
|
||||||
ax.set_ylabel('X')
|
ax.set_ylabel("X")
|
||||||
ax.set_zlabel('Y')
|
ax.set_zlabel("Y")
|
||||||
# ax.set_xlim(-0.5, 0.5)
|
# ax.set_xlim(-0.5, 0.5)
|
||||||
# ax.set_ylim(-0.5, 0.5)
|
# ax.set_ylim(-0.5, 0.5)
|
||||||
# ax.set_zlim(-0.5, 0.5)
|
# ax.set_zlim(-0.5, 0.5)
|
||||||
|
@ -178,37 +195,39 @@ def visualize_pointcloud(points, normals=None,
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def visualize_pointcloud_batch(path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225):
|
def visualize_pointcloud_batch(
|
||||||
|
path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225
|
||||||
|
):
|
||||||
batch_size = len(pointclouds)
|
batch_size = len(pointclouds)
|
||||||
fig = plt.figure(figsize=(20,20))
|
fig = plt.figure(figsize=(20, 20))
|
||||||
|
|
||||||
ncols = int(np.sqrt(batch_size))
|
ncols = int(np.sqrt(batch_size))
|
||||||
nrows = max(1, (batch_size-1) // ncols+1)
|
nrows = max(1, (batch_size - 1) // ncols + 1)
|
||||||
for idx, pc in enumerate(pointclouds):
|
for idx, pc in enumerate(pointclouds):
|
||||||
if vis_label:
|
if vis_label:
|
||||||
label = categories[labels[idx].item()]
|
label = categories[labels[idx].item()]
|
||||||
pred = categories[pred_labels[idx]]
|
pred = categories[pred_labels[idx]]
|
||||||
colour = 'g' if label == pred else 'r'
|
colour = "g" if label == pred else "r"
|
||||||
elif target is None:
|
elif target is None:
|
||||||
|
colour = "g"
|
||||||
colour = 'g'
|
|
||||||
else:
|
else:
|
||||||
colour = target[idx]
|
colour = target[idx]
|
||||||
pc = pc.cpu().numpy()
|
pc = pc.cpu().numpy()
|
||||||
ax = fig.add_subplot(nrows, ncols, idx + 1, projection='3d')
|
ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
|
||||||
ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
|
ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
|
||||||
ax.view_init(elev=elev, azim=azim)
|
ax.view_init(elev=elev, azim=azim)
|
||||||
ax.axis('off')
|
ax.axis("off")
|
||||||
if vis_label:
|
if vis_label:
|
||||||
ax.set_title('GT: {0}\nPred: {1}'.format(label, pred))
|
ax.set_title("GT: {0}\nPred: {1}".format(label, pred))
|
||||||
|
|
||||||
plt.savefig(path)
|
plt.savefig(path)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
'''
|
"""
|
||||||
Plot stats
|
Plot stats
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def plot_stats(output_dir, stats, interval):
|
def plot_stats(output_dir, stats, interval):
|
||||||
content = stats.keys()
|
content = stats.keys()
|
||||||
|
@ -218,5 +237,5 @@ def plot_stats(output_dir, stats, interval):
|
||||||
axs[j].plot(interval, v)
|
axs[j].plot(interval, v)
|
||||||
axs[j].set_ylabel(k)
|
axs[j].set_ylabel(k)
|
||||||
|
|
||||||
f.savefig(os.path.join(output_dir, 'stat.pdf'), bbox_inches='tight')
|
f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
|
||||||
plt.close(f)
|
plt.close(f)
|
||||||
|
|
Loading…
Reference in a new issue