style: autoformatting

Laurent FAINSIN 2023-04-11 11:12:58 +02:00
parent d887d74852
commit 2fbfc320f2
44 changed files with 2424 additions and 1831 deletions


@@ -1,32 +1,33 @@
-from glob import glob
-import re
 import argparse
-import numpy as np
-from pathlib import Path
 import os
+import re
+from glob import glob
+from pathlib import Path
+
+import numpy as np


 def raw_camparam_from_xml(path, pose="lookAt"):
     import xml.etree.ElementTree as ET

     tree = ET.parse(path)
     elm = tree.find("./sensor/transform/" + pose)
     camparam = elm.attrib
-    origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
-    target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
-    up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
-    height = int(
-        tree.find("./sensor/film/integer[@name='height']").attrib['value'])
-    width = int(
-        tree.find("./sensor/film/integer[@name='width']").attrib['value'])
+    origin = np.fromstring(camparam["origin"], dtype=np.float32, sep=",")
+    target = np.fromstring(camparam["target"], dtype=np.float32, sep=",")
+    up = np.fromstring(camparam["up"], dtype=np.float32, sep=",")
+    height = int(tree.find("./sensor/film/integer[@name='height']").attrib["value"])
+    width = int(tree.find("./sensor/film/integer[@name='width']").attrib["value"])

     camparam = dict()
-    camparam['origin'] = origin
-    camparam['up'] = up
-    camparam['target'] = target
-    camparam['height'] = height
-    camparam['width'] = width
+    camparam["origin"] = origin
+    camparam["up"] = up
+    camparam["target"] = target
+    camparam["height"] = height
+    camparam["width"] = width
     return camparam


 def get_cam_pos(origin, target, up):
     inward = origin - target
     right = np.cross(up, inward)
@@ -38,59 +39,54 @@ def get_cam_pos(origin, target, up):
     ry /= np.linalg.norm(ry)
     rz /= np.linalg.norm(rz)

-    rot = np.stack([
-        rx,
-        ry,
-        -rz
-    ], axis=0)
-    aff = np.concatenate([
-        np.eye(3), -origin[:,None]
-    ], axis=1)
+    rot = np.stack([rx, ry, -rz], axis=0)
+    aff = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
     ext = np.matmul(rot, aff)
-    result = np.concatenate(
-        [ext, np.array([[0,0,0,1]])], axis=0
-    )
+    result = np.concatenate([ext, np.array([[0, 0, 0, 1]])], axis=0)
     return result


 def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
-    depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
-    cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths]
+    depths = sorted(glob(os.path.join(datapoint_dir, "*depth.png")))
+    cam_ext = [
+        "_".join(re.sub(dataroot.strip("/"), camera_param_dir.strip("/"), f).split("_")[:-1]) + ".xml" for f in depths
+    ]
     for i, (f, pth) in enumerate(zip(cam_ext, depths)):
         if not os.path.exists(f):
             continue
-        params=raw_camparam_from_xml(f)
-        origin, target, up, width, height = params['origin'], params['target'], params['up'],\
-            params['width'], params['height']
+        params = raw_camparam_from_xml(f)
+        origin, target, up, width, height = (
+            params["origin"],
+            params["target"],
+            params["up"],
+            params["width"],
+            params["height"],
+        )
         ext_matrix = get_cam_pos(origin, target, up)

         #####
-        diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
+        diag = (0.036**2 + 0.024**2) ** 0.5
         focal_length = 0.05
         res = [480, 480]
-        h_relative = (res[1] / res[0])
-        sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
+        h_relative = res[1] / res[0]
+        sensor_width = np.sqrt(diag**2 / (1 + h_relative**2))
         pix_size = sensor_width / res[0]
-        K = np.array([
-            [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
-            [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
-            [0, 0, 1]
-        ])
+        K = np.array(
+            [
+                [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
+                [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
+                [0, 0, 1],
+            ]
+        )

-        np.savez(pth.split('depth.png')[0]+ 'cam_params.npz', extr=ext_matrix, intr=K)
+        np.savez(pth.split("depth.png")[0] + "cam_params.npz", extr=ext_matrix, intr=K)


 def main(opt):
@@ -102,21 +98,16 @@ def main(opt):
         if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
             leaf_subdirs.append(dirpath)

     for k, dir_ in enumerate(leaf_subdirs):
-        print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
+        print("Processing dir {}/{}: {}".format(k, len(leaf_subdirs), dir_))

         convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = argparse.ArgumentParser()
-    args.add_argument('--dataroot', type=str, default='GenReData/')
-    args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
+    args.add_argument("--dataroot", type=str, default="GenReData/")
+    args.add_argument("--mitsuba_xml_root", type=str, default="GenReData/genre-xml_v2")
     opt = args.parse_args()
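Side note on the script above: each depth view ends up with a cam_params.npz holding a 4x4 extrinsic and a 3x3 pinhole intrinsic. A minimal consumption sketch (the file name is illustrative, not from the repository):

import numpy as np

params = np.load("view0_cam_params.npz")  # hypothetical file written by the script above
extr = params["extr"]  # 4x4 world-to-camera matrix [R | -R @ origin; 0 0 0 1]
intr = params["intr"]  # 3x3 pinhole intrinsic K

X = np.array([0.1, 0.2, 0.3, 1.0])  # a world-space point in homogeneous coordinates
x_cam = extr @ X                    # world frame -> camera frame
u, v, w = intr @ x_cam[:3]          # camera frame -> image plane
pixel = (u / w, v / w)              # perspective divide gives pixel coordinates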


@@ -1,11 +1,13 @@
-from torch.utils.data import Dataset, DataLoader
-import torch
-import numpy as np
-import os
 import json
+import os
 import random
+
+import numpy as np
+import torch
 import trimesh
 from plyfile import PlyData, PlyElement
+from torch.utils.data import Dataset


 def project_pc_to_image(points, resolution=64):
     """project point clouds into 2D image
@@ -26,29 +28,32 @@ def project_pc_to_image(points, resolution=64):
 def write_ply(points, filename, text=False):
-    """ input: Nx3, write points to filename as PLY format. """
-    points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
-    vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
-    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
-    with open(filename, mode='wb') as f:
+    """input: Nx3, write points to filename as PLY format."""
+    points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
+    vertex = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
+    el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
+    with open(filename, mode="wb") as f:
         PlyData([el], text=text).write(f)


 def rotate_point_cloud(points, transformation_mat):
     new_points = np.dot(transformation_mat, points.T).T
     return new_points


 def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
-    """ align 3depn shapes to shapenet coordinates"""
+    """align 3depn shapes to shapenet coordinates"""
     # angle = math.radians(angle_deg)
     # rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
     # rot_m = rot_m.to_matrix()
-    rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
-                      [ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
-                      [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
+    rot_m = np.array(
+        [
+            [2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
+            [0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
+            [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16],
+        ]
+    )

     new_points = rotate_point_cloud(points, rot_m)
@@ -87,14 +92,13 @@ def sample_point_cloud_by_n(points, n_pts):
     return points


 def collect_data_id(split_dir, classname, phase):
     filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
     if not os.path.exists(filename):
         raise ValueError("Invalid filepath: {}".format(filename))

     all_ids = []
-    with open(filename, 'r') as fp:
+    with open(filename, "r") as fp:
         info = json.load(fp)
         for item in info:
             all_ids.append(item["anno_id"])
@@ -102,7 +106,6 @@ def collect_data_id(split_dir, classname, phase):
     return all_ids

-
 class GANdatasetPartNet(Dataset):
     def __init__(self, phase, data_root, category, n_pts):
         super(GANdatasetPartNet, self).__init__()
@@ -114,10 +117,12 @@ class GANdatasetPartNet(Dataset):
         self.data_root = data_root

-        shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase)
+        shape_names = collect_data_id(
+            os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
+        )
         self.shape_names = []
         for name in shape_names:
-            path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
+            path = os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name)
             if os.path.exists(path):
                 self.shape_names.append(name)
@@ -129,12 +134,12 @@ class GANdatasetPartNet(Dataset):
     @staticmethod
     def load_point_cloud(path):
         pc = trimesh.load(path)
-        pc = pc.vertices / 2.0 # scale to unit sphere
+        pc = pc.vertices / 2.0  # scale to unit sphere
         return pc

     @staticmethod
     def read_point_cloud_part_label(path):
-        with open(path, 'r') as fp:
+        with open(path, "r") as fp:
             labels = fp.readlines()
         labels = np.array([int(x) for x in labels])
         return labels
@@ -156,26 +161,31 @@ class GANdatasetPartNet(Dataset):
     def __getitem__(self, index):
         raw_shape_name = self.shape_names[index]
-        raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
+        raw_ply_path = os.path.join(self.data_root, "partnet_data", raw_shape_name, "point_sample/ply-10000.ply")
         raw_pc = self.load_point_cloud(raw_ply_path)
-        raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
+        raw_label_path = os.path.join(
+            self.data_root, "partnet_labels/partnet_pc_label", raw_shape_name, "label-merge-level1-10000.txt"
+        )
         part_labels = self.read_point_cloud_part_label(raw_label_path)
         raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
         raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
         raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)

         real_shape_name = self.shape_names[index]
-        real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
+        real_ply_path = os.path.join(self.data_root, "partnet_data", real_shape_name, "point_sample/ply-10000.ply")
         real_pc = self.load_point_cloud(real_ply_path)
         real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
         real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)

-        return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
-                'n_part_keep': n_part_keep, 'idx': index}
+        return {
+            "raw": raw_pc,
+            "real": real_pc,
+            "raw_id": raw_shape_name,
+            "real_id": real_shape_name,
+            "n_part_keep": n_part_keep,
+            "idx": index,
+        }

     def __len__(self):
         return len(self.shape_names)


@@ -1,33 +1,67 @@
 import os
-import torch
-import numpy as np
-from torch.utils.data import Dataset
-from torch.utils import data
 import random
+
 import numpy as np
-import torch.nn.functional as F
+import torch
+from torch.utils.data import Dataset

 # taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
 synsetid_to_cate = {
-    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
-    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
-    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
-    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
-    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
-    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
-    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
-    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
-    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
-    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
-    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
-    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
-    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
-    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
-    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
-    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
-    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
-    '04554684': 'washer', '02992529': 'cellphone',
-    '02843684': 'birdhouse', '02871439': 'bookshelf',
+    "02691156": "airplane",
+    "02773838": "bag",
+    "02801938": "basket",
+    "02808440": "bathtub",
+    "02818832": "bed",
+    "02828884": "bench",
+    "02876657": "bottle",
+    "02880940": "bowl",
+    "02924116": "bus",
+    "02933112": "cabinet",
+    "02747177": "can",
+    "02942699": "camera",
+    "02954340": "cap",
+    "02958343": "car",
+    "03001627": "chair",
+    "03046257": "clock",
+    "03207941": "dishwasher",
+    "03211117": "monitor",
+    "04379243": "table",
+    "04401088": "telephone",
+    "02946921": "tin_can",
+    "04460130": "tower",
+    "04468005": "train",
+    "03085013": "keyboard",
+    "03261776": "earphone",
+    "03325088": "faucet",
+    "03337140": "file",
+    "03467517": "guitar",
+    "03513137": "helmet",
+    "03593526": "jar",
+    "03624134": "knife",
+    "03636649": "lamp",
+    "03642806": "laptop",
+    "03691459": "speaker",
+    "03710193": "mailbox",
+    "03759954": "microphone",
+    "03761084": "microwave",
+    "03790512": "motorcycle",
+    "03797390": "mug",
+    "03928116": "piano",
+    "03938244": "pillow",
+    "03948459": "pistol",
+    "03991062": "pot",
+    "04004475": "printer",
+    "04074963": "remote_control",
+    "04090263": "rifle",
+    "04099429": "rocket",
+    "04225987": "skateboard",
+    "04256520": "sofa",
+    "04330267": "stove",
+    "04530566": "vessel",
+    "04554684": "washer",
+    "02992529": "cellphone",
+    "02843684": "birdhouse",
+    "02871439": "bookshelf",
     # '02858304': 'boat', no boat in our dataset, merged into vessels
     # '02834778': 'bicycle', not in our taxonomy
 }
@@ -35,13 +69,23 @@ cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}


 class Uniform15KPC(Dataset):
-    def __init__(self, root_dir, subdirs, tr_sample_size=10000,
-                 te_sample_size=10000, split='train', scale=1.,
-                 normalize_per_shape=False, box_per_shape=False,
-                 random_subsample=False,
-                 normalize_std_per_axis=False,
-                 all_points_mean=None, all_points_std=None,
-                 input_dim=3, use_mask=False):
+    def __init__(
+        self,
+        root_dir,
+        subdirs,
+        tr_sample_size=10000,
+        te_sample_size=10000,
+        split="train",
+        scale=1.0,
+        normalize_per_shape=False,
+        box_per_shape=False,
+        random_subsample=False,
+        normalize_std_per_axis=False,
+        all_points_mean=None,
+        all_points_std=None,
+        input_dim=3,
+        use_mask=False,
+    ):
         self.root_dir = root_dir
         self.split = split
         self.in_tr_sample_size = tr_sample_size
@@ -67,9 +111,9 @@ class Uniform15KPC(Dataset):
             all_mids = []
             for x in os.listdir(sub_path):
-                if not x.endswith('.npy'):
+                if not x.endswith(".npy"):
                     continue
-                all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
+                all_mids.append(os.path.join(self.split, x[: -len(".npy")]))

             # NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
             for mid in all_mids:
@@ -111,7 +155,9 @@ class Uniform15KPC(Dataset):
             B, N = self.all_points.shape[:2]
             self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
-            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
+            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(
+                axis=1
+            ).reshape(B, 1, input_dim)

         else:  # normalize across the dataset
             self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
@@ -129,8 +175,7 @@ class Uniform15KPC(Dataset):
         self.tr_sample_size = min(10000, tr_sample_size)
         self.te_sample_size = min(5000, te_sample_size)
         print("Total number of data:%d" % len(self.train_points))
-        print("Min number of points: (train)%d (test)%d"
-              % (self.tr_sample_size, self.te_sample_size))
+        print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
         assert self.scale == 1, "Scale (!= 1) is deprecated"

     def get_pc_stats(self, idx):
@@ -139,7 +184,6 @@ class Uniform15KPC(Dataset):
             s = self.all_points_std[idx].reshape(1, -1)
             return m, s
-
         return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

     def renormalize(self, mean, std):
@@ -173,11 +217,14 @@ class Uniform15KPC(Dataset):
         sid, mid = self.all_cate_mids[idx]

         out = {
-            'idx': idx,
-            'train_points': tr_out,
-            'test_points': te_out,
-            'mean': m, 'std': s, 'cate_idx': cate_idx,
-            'sid': sid, 'mid': mid
+            "idx": idx,
+            "train_points": tr_out,
+            "test_points": te_out,
+            "mean": m,
+            "std": s,
+            "cate_idx": cate_idx,
+            "sid": sid,
+            "mid": mid,
         }

         if self.use_mask:
@@ -192,26 +239,35 @@ class Uniform15KPC(Dataset):
             # out['train_points_masked'] = masked
             # out['train_masks'] = tr_mask
             tr_mask = self.mask_transform(tr_out)
-            out['train_masks'] = tr_mask
+            out["train_masks"] = tr_mask

         return out


 class ShapeNet15kPointClouds(Uniform15KPC):
-    def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
-                 categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
-                 split='train', scale=1., normalize_per_shape=False,
-                 normalize_std_per_axis=False, box_per_shape=False,
-                 random_subsample=False,
-                 all_points_mean=None, all_points_std=None,
-                 use_mask=False):
+    def __init__(
+        self,
+        root_dir="data/ShapeNetCore.v2.PC15k",
+        categories=["airplane"],
+        tr_sample_size=10000,
+        te_sample_size=2048,
+        split="train",
+        scale=1.0,
+        normalize_per_shape=False,
+        normalize_std_per_axis=False,
+        box_per_shape=False,
+        random_subsample=False,
+        all_points_mean=None,
+        all_points_std=None,
+        use_mask=False,
+    ):
         self.root_dir = root_dir
         self.split = split
-        assert self.split in ['train', 'test', 'val']
+        assert self.split in ["train", "test", "val"]
         self.tr_sample_size = tr_sample_size
         self.te_sample_size = te_sample_size
         self.cates = categories
-        if 'all' in categories:
+        if "all" in categories:
             self.synset_ids = list(cate_to_synsetid.values())
         else:
             self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
@@ -221,19 +277,21 @@ class ShapeNet15kPointClouds(Uniform15KPC):
         self.display_axis_order = [0, 2, 1]

         super(ShapeNet15kPointClouds, self).__init__(
-            root_dir, self.synset_ids,
+            root_dir,
+            self.synset_ids,
             tr_sample_size=tr_sample_size,
             te_sample_size=te_sample_size,
-            split=split, scale=scale,
-            normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
+            split=split,
+            scale=scale,
+            normalize_per_shape=normalize_per_shape,
+            box_per_shape=box_per_shape,
             normalize_std_per_axis=normalize_std_per_axis,
             random_subsample=random_subsample,
-            all_points_mean=all_points_mean, all_points_std=all_points_std,
-            input_dim=3, use_mask=use_mask)
+            all_points_mean=all_points_mean,
+            all_points_std=all_points_std,
+            input_dim=3,
+            use_mask=use_mask,
+        )


 ####################################################################################
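For orientation, a minimal usage sketch of ShapeNet15kPointClouds (a hypothetical call using the defaults from the signature above; it assumes the PC15k data directory exists on disk):

from torch.utils.data import DataLoader

dataset = ShapeNet15kPointClouds(
    root_dir="data/ShapeNetCore.v2.PC15k",
    categories=["airplane"],
    split="train",
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
batch = next(iter(loader))  # dict with "train_points", "test_points", "mean", "std", "cate_idx", "sid", "mid"
print(batch["train_points"].shape, batch["mean"].shape, batch["std"].shape)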


@@ -1,34 +1,70 @@
+import hashlib
+import os
 import warnings
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
 from torch.utils.data import Dataset
 from tqdm import tqdm
-from pathlib import Path
-import os
-import numpy as np
-import hashlib
-import torch
-import matplotlib.pyplot as plt

 synset_to_label = {
-    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
-    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
-    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
-    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
-    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
-    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
-    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
-    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
-    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
-    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
-    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
-    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
-    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
-    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
-    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
-    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
-    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
-    '04554684': 'washer', '02992529': 'cellphone',
-    '02843684': 'birdhouse', '02871439': 'bookshelf',
+    "02691156": "airplane",
+    "02773838": "bag",
+    "02801938": "basket",
+    "02808440": "bathtub",
+    "02818832": "bed",
+    "02828884": "bench",
+    "02876657": "bottle",
+    "02880940": "bowl",
+    "02924116": "bus",
+    "02933112": "cabinet",
+    "02747177": "can",
+    "02942699": "camera",
+    "02954340": "cap",
+    "02958343": "car",
+    "03001627": "chair",
+    "03046257": "clock",
+    "03207941": "dishwasher",
+    "03211117": "monitor",
+    "04379243": "table",
+    "04401088": "telephone",
+    "02946921": "tin_can",
+    "04460130": "tower",
+    "04468005": "train",
+    "03085013": "keyboard",
+    "03261776": "earphone",
+    "03325088": "faucet",
+    "03337140": "file",
+    "03467517": "guitar",
+    "03513137": "helmet",
+    "03593526": "jar",
+    "03624134": "knife",
+    "03636649": "lamp",
+    "03642806": "laptop",
+    "03691459": "speaker",
+    "03710193": "mailbox",
+    "03759954": "microphone",
+    "03761084": "microwave",
+    "03790512": "motorcycle",
+    "03797390": "mug",
+    "03928116": "piano",
+    "03938244": "pillow",
+    "03948459": "pistol",
+    "03991062": "pot",
+    "04004475": "printer",
+    "04074963": "remote_control",
+    "04090263": "rifle",
+    "04099429": "rocket",
+    "04225987": "skateboard",
+    "04256520": "sofa",
+    "04330267": "stove",
+    "04530566": "vessel",
+    "04554684": "washer",
+    "02992529": "cellphone",
+    "02843684": "birdhouse",
+    "02871439": "bookshelf",
     # '02858304': 'boat', no boat in our dataset, merged into vessels
     # '02834778': 'bicycle', not in our taxonomy
 }
@@ -36,30 +72,44 @@ synset_to_label = {
 # Label to Synset mapping (for ShapeNet core classes)
 label_to_synset = {v: k for k, v in synset_to_label.items()}


 def _convert_categories(categories):
-    assert categories is not None, 'List of categories cannot be empty!'
-    if not (c in synset_to_label.keys() + label_to_synset.keys()
-            for c in categories):
-        warnings.warn('Some or all of the categories requested are not part of \
-            ShapeNetCore. Data loading may fail if these categories are not avaliable.')
-    synsets = [label_to_synset[c] if c in label_to_synset.keys()
-               else c for c in categories]
+    assert categories is not None, "List of categories cannot be empty!"
+    if not (c in synset_to_label.keys() + label_to_synset.keys() for c in categories):
+        warnings.warn(
+            "Some or all of the categories requested are not part of \
+            ShapeNetCore. Data loading may fail if these categories are not avaliable."
+        )
+    synsets = [label_to_synset[c] if c in label_to_synset.keys() else c for c in categories]
     return synsets


 class ShapeNet_Multiview_Points(Dataset):
-    def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val',
-                 npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
+    def __init__(
+        self,
+        root_pc: str,
+        root_views: str,
+        cache: str,
+        categories: list = ["chair"],
+        split: str = "val",
+        npoints=2048,
+        sv_samples=800,
+        all_points_mean=None,
+        all_points_std=None,
+        get_image=False,
+    ):
         self.root = Path(root_views)
         self.split = split
         self.get_image = get_image
         params = {
-            'cat': categories,
-            'npoints': npoints,
-            'sv_samples': sv_samples,
+            "cat": categories,
+            "npoints": npoints,
+            "sv_samples": sv_samples,
         }
         params = tuple(sorted(pair for pair in params.items()))
-        self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
+        self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
+            "_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
+        )

         self.cache_dir.mkdir(parents=True, exist_ok=True)
         self.paths = []
@@ -74,13 +124,12 @@ class ShapeNet_Multiview_Points(Dataset):
         # loops through desired classes
         for i in range(len(self.synsets)):
             syn = self.synsets[i]
             class_target = self.root / syn
             if not class_target.exists():
-                raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
-                    syn, self.labels[i], str(class_target)))
+                raise ValueError(
+                    "Class {0} ({1}) was not found at location {2}.".format(syn, self.labels[i], str(class_target))
+                )

             sub_path_pc = os.path.join(root_pc, syn, split)
             if not os.path.isdir(sub_path_pc):
@@ -90,30 +139,30 @@ class ShapeNet_Multiview_Points(Dataset):
             self.all_mids = []
             self.imgs = []
             for x in os.listdir(sub_path_pc):
-                if not x.endswith('.npy'):
+                if not x.endswith(".npy"):
                     continue
-                self.all_mids.append(os.path.join(split, x[:-len('.npy')]))
+                self.all_mids.append(os.path.join(split, x[: -len(".npy")]))

             for mid in tqdm(self.all_mids):
                 # obj_fname = os.path.join(sub_path, x)
                 obj_fname = os.path.join(root_pc, syn, mid + ".npy")
-                cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz'))
+                cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
                 if len(cams_pths) < 20:
                     continue
                 point_cloud = np.load(obj_fname)
                 sv_points_group = []
                 img_path_group = []
-                (self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True)
+                (self.cache_dir / (mid.split("/")[-1])).mkdir(parents=True, exist_ok=True)
                 success = True
                 for i, cp in enumerate(cams_pths):
                     cp = str(cp)
-                    vp = cp.split('cam_params')[0] + 'depth.png'
-                    depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
-                    cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) )
+                    vp = cp.split("cam_params")[0] + "depth.png"
+                    depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
+                    cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))

                     cam_params = np.load(cp)
-                    extr = cam_params['extr']
-                    intr = cam_params['intr']
+                    extr = cam_params["extr"]
+                    intr = cam_params["intr"]

                     self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
@@ -125,7 +174,7 @@ class ShapeNet_Multiview_Points(Dataset):
                         sv_points_group.append(sv_point_cloud)
                     except Exception as e:
                         print(e)
-                        success=False
+                        success = False
                         break
                 if not success:
                     continue
@@ -144,64 +193,66 @@ class ShapeNet_Multiview_Points(Dataset):
         self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

         self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
-        self.train_points = self.all_points[:,:10000]
-        self.test_points = self.all_points[:,10000:]
+        self.train_points = self.all_points[:, :10000]
+        self.test_points = self.all_points[:, 10000:]
         self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std

     def get_pc_stats(self, idx):
-        return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1)
+        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)

     def __len__(self):
-        """Returns the length of the dataset. """
+        """Returns the length of the dataset."""
        return len(self.all_points)

     def __getitem__(self, index):
         tr_out = self.train_points[index]
         tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
         tr_out = tr_out[tr_idxs, :]

-        gt_points = self.test_points[index][:self.npoints]
+        gt_points = self.test_points[index][: self.npoints]

         m, s = self.get_pc_stats(index)

         sv_points = self.all_points_sv[index]
-        idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False)
+        idxs = np.arange(0, sv_points.shape[-2])[
+            : self.sv_samples
+        ]  # np.random.choice(sv_points.shape[0], 500, replace=False)

-        data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(),
-                          torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
+        data = torch.cat(
+            [
+                torch.from_numpy(sv_points[:, idxs]).float(),
+                torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
+            ],
+            dim=1,
+        )

         masks = torch.zeros_like(data)
-        masks[:,:idxs.shape[0]] = 1
+        masks[:, : idxs.shape[0]] = 1

-        res = {'train_points': torch.from_numpy(tr_out).float(),
-               'test_points': torch.from_numpy(gt_points).float(),
-               'sv_points': data,
-               'masks': masks,
-               'std': s, 'mean': m,
-               'idx': index,
-               'name':self.all_mids[index]
-               }
+        res = {
+            "train_points": torch.from_numpy(tr_out).float(),
+            "test_points": torch.from_numpy(gt_points).float(),
+            "sv_points": data,
+            "masks": masks,
+            "std": s,
+            "mean": m,
+            "idx": index,
+            "name": self.all_mids[index],
+        }

-        if self.split != 'train' and self.get_image:
+        if self.split != "train" and self.get_image:
             img_lst = []
             for n in range(self.all_points_sv.shape[1]):
-                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3]
+                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
                 img_lst.append(img)
             img = torch.stack(img_lst, dim=0)
-            res['image'] = img
+            res["image"] = img
         return res

     def _render(self, cache_path, depth_pth, depth_minmax_pth):
         # if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
         #
@@ -210,11 +261,9 @@ class ShapeNet_Multiview_Points(Dataset):
         if os.path.exists(cache_path):
             data = np.load(cache_path)
         else:
             data, depth = self.transform(depth_pth, depth_minmax_pth)
-            assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
+            assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
             data = data[np.random.choice(data.shape[0], 600, replace=False)]
             np.save(cache_path, data)

         return data


@@ -1,25 +1,32 @@
-from torch import nn
-from torch.autograd import Function
-import torch
 import importlib
 import os

+import torch
+from torch import nn
+from torch.autograd import Function
+
 chamfer_found = importlib.find_loader("chamfer_2D") is not None
 if not chamfer_found:
     ## Cool trick from https://github.com/chrdiller
     print("Jitting Chamfer 2D")
     from torch.utils.cpp_extension import load

-    chamfer_2D = load(name="chamfer_2D",
-          sources=[
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
-          ])
+    chamfer_2D = load(
+        name="chamfer_2D",
+        sources=[
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer2D.cu"]),
+        ],
+    )
     print("Loaded JIT 2D CUDA chamfer distance")
 else:
     import chamfer_2D
     print("Loaded compiled 2D CUDA chamfer distance")


 # Chamfer's distance module @thibaultgroueix
 # GPU tensors only
 class chamfer_2DFunction(Function):
@@ -57,9 +64,7 @@ class chamfer_2DFunction(Function):
         gradxyz1 = gradxyz1.to(device)
         gradxyz2 = gradxyz2.to(device)
-        chamfer_2D.backward(
-            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
-        )
+        chamfer_2D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
         return gradxyz1, gradxyz2


@@ -2,15 +2,16 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 setup(
-    name='chamfer_2D',
+    name="chamfer_2D",
     ext_modules=[
-        CUDAExtension('chamfer_2D', [
-            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
-            "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
-        ]),
+        CUDAExtension(
+            "chamfer_2D",
+            [
+                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
+                "/".join(__file__.split("/")[:-1] + ["chamfer2D.cu"]),
+            ],
+        ),
     ],
-    extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    cmdclass={"build_ext": BuildExtension},
+)


@@ -1,25 +1,30 @@
-from torch import nn
-from torch.autograd import Function
-import torch
 import importlib
 import os

+import torch
+from torch import nn
+from torch.autograd import Function
+
 chamfer_found = importlib.find_loader("chamfer_3D") is not None
 if not chamfer_found:
     ## Cool trick from https://github.com/chrdiller
     print("Jitting Chamfer 3D")
     from torch.utils.cpp_extension import load

-    chamfer_3D = load(name="chamfer_3D",
-          sources=[
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
-          ],
-          extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],)
+    chamfer_3D = load(
+        name="chamfer_3D",
+        sources=[
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer3D.cu"]),
+        ],
+        extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    )
     print("Loaded JIT 3D CUDA chamfer distance")
 else:
     import chamfer_3D
     print("Loaded compiled 3D CUDA chamfer distance")
@@ -60,9 +65,7 @@ class chamfer_3DFunction(Function):
         gradxyz1 = gradxyz1.to(device)
         gradxyz2 = gradxyz2.to(device)
-        chamfer_3D.backward(
-            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
-        )
+        chamfer_3D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
         return gradxyz1, gradxyz2
@@ -74,4 +77,3 @@ class chamfer_3DDist(nn.Module):
         input1 = input1.contiguous()
         input2 = input2.contiguous()
         return chamfer_3DFunction.apply(input1, input2)
-


@@ -2,15 +2,16 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 setup(
-    name='chamfer_3D',
+    name="chamfer_3D",
     ext_modules=[
-        CUDAExtension('chamfer_3D', [
-            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
-            "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
-        ]),
+        CUDAExtension(
+            "chamfer_3D",
+            [
+                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
+                "/".join(__file__.split("/")[:-1] + ["chamfer3D.cu"]),
+            ],
+        ),
     ],
-    extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    cmdclass={"build_ext": BuildExtension},
+)


@@ -1,24 +1,29 @@
-from torch import nn
-from torch.autograd import Function
-import torch
 import importlib
 import os

+import torch
+from torch import nn
+from torch.autograd import Function
+
 chamfer_found = importlib.find_loader("chamfer_5D") is not None
 if not chamfer_found:
     ## Cool trick from https://github.com/chrdiller
     print("Jitting Chamfer 5D")
     from torch.utils.cpp_extension import load

-    chamfer_5D = load(name="chamfer_5D",
-          sources=[
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
-          ])
+    chamfer_5D = load(
+        name="chamfer_5D",
+        sources=[
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer5D.cu"]),
+        ],
+    )
     print("Loaded JIT 5D CUDA chamfer distance")
 else:
     import chamfer_5D
     print("Loaded compiled 5D CUDA chamfer distance")
@@ -59,9 +64,7 @@ class chamfer_5DFunction(Function):
         gradxyz1 = gradxyz1.to(device)
         gradxyz2 = gradxyz2.to(device)
-        chamfer_5D.backward(
-            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
-        )
+        chamfer_5D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
         return gradxyz1, gradxyz2


@@ -2,15 +2,16 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 setup(
-    name='chamfer_5D',
+    name="chamfer_5D",
     ext_modules=[
-        CUDAExtension('chamfer_5D', [
-            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
-            "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
-        ]),
+        CUDAExtension(
+            "chamfer_5D",
+            [
+                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
+                "/".join(__file__.split("/")[:-1] + ["chamfer5D.cu"]),
+            ],
+        ),
     ],
-    extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    cmdclass={"build_ext": BuildExtension},
+)


@@ -33,8 +33,7 @@ def distChamfer(a, b):
     xx = torch.pow(x, 2).sum(2)
     yy = torch.pow(y, 2).sum(2)
     zz = torch.bmm(x, y.transpose(2, 1))
-    rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
-    ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
+    rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x)  # Diagonal elements xx
+    ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y)  # Diagonal elements yy
     P = rx.transpose(2, 1) + ry - 2 * zz
     return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
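The vectorized distChamfer above relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, computed batch-wise with bmm. A small self-contained sanity check of that identity against torch.cdist:

import torch

a, b = torch.rand(2, 5, 3), torch.rand(2, 7, 3)
xx = a.pow(2).sum(2)                  # ||x||^2 per point, shape (B, Na)
yy = b.pow(2).sum(2)                  # ||y||^2 per point, shape (B, Nb)
zz = torch.bmm(a, b.transpose(2, 1))  # inner products <x, y>, shape (B, Na, Nb)
P = xx.unsqueeze(2) + yy.unsqueeze(1) - 2 * zz  # pairwise squared distances
assert torch.allclose(P, torch.cdist(a, b) ** 2, atol=1e-4)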


@@ -1,5 +1,6 @@
 import torch

+
 def fscore(dist1, dist2, threshold=0.001):
     """
     Calculates the F-score between two point clouds with the corresponding threshold value.
@@ -14,4 +15,3 @@ def fscore(dist1, dist2, threshold=0.001):
     fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
     fscore[torch.isnan(fscore)] = 0
     return fscore, precision_1, precision_2
-
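A usage sketch combining the CUDA chamfer distance with fscore, mirroring the unit test later in this commit (requires the compiled extension and a GPU):

import torch
from chamfer3D.dist_chamfer_3D import chamfer_3DDist
from fscore import fscore

cham3D = chamfer_3DDist()
points1 = torch.rand(4, 100, 3).cuda()
points2 = torch.rand(4, 200, 3).cuda()
dist1, dist2, idx1, idx2 = cham3D(points1, points2)  # squared distances in both directions
f, precision_1, precision_2 = fscore(dist1, dist2, threshold=0.001)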


@@ -1,20 +1,23 @@
-import torch, time
+import time
+
 import chamfer2D.dist_chamfer_2D
 import chamfer3D.dist_chamfer_3D
 import chamfer5D.dist_chamfer_5D
 import chamfer_python
+import torch

 cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
 cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
 cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()

-from torch.autograd import Variable
 from fscore import fscore
+from torch.autograd import Variable


 def test_chamfer(distChamfer, dim):
     points1 = torch.rand(4, 100, dim).cuda()
     points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
-    dist1, dist2, idx1, idx2= distChamfer(points1, points2)
+    dist1, dist2, idx1, idx2 = distChamfer(points1, points2)

     loss = torch.sum(dist1)
     loss.backward()
@@ -29,9 +32,9 @@ def test_chamfer(distChamfer, dim):
     xd1 = idx1 - myidx1
     xd2 = idx2 - myidx2
     assert (
         torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
     ), "chamfer cuda and chamfer normal are not giving the same results"
-    print(f"fscore :", fscore(dist1, dist2))
+    print("fscore :", fscore(dist1, dist2))
     print("Unit test passed")
@@ -49,7 +52,6 @@ def timings(distChamfer, dim):
     loss.backward()
     print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
-
     print("Timings : Start Pythonic version")
     start = time.time()
     for i in range(num_it):
@@ -61,9 +63,8 @@ def timings(distChamfer, dim):
     print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")


-dims = [2,3,5]
-for i,cham in enumerate([cham2D, cham3D, cham5D]):
+dims = [2, 3, 5]
+for i, cham in enumerate([cham2D, cham3D, cham5D]):
     print(f"testing Chamfer {dims[i]}D")
     test_chamfer(cham, dims[i])
     timings(cham, dims[i])


@@ -1,5 +1,5 @@
-import torch
 import emd_cuda
+import torch


 class EarthMoverDistanceFunction(torch.autograd.Function):
@@ -44,4 +44,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
     cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
     cost = cost / xyz1.shape[1]
     return cost
-


@@ -9,19 +9,17 @@ Notes:
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 setup(
-    name='emd_ext',
+    name="emd_ext",
     ext_modules=[
         CUDAExtension(
-            name='emd_cuda',
+            name="emd_cuda",
             sources=[
-                'cuda/emd.cpp',
-                'cuda/emd_kernel.cu',
+                "cuda/emd.cpp",
+                "cuda/emd_kernel.cu",
             ],
-            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
+            extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
         ),
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
+    cmdclass={"build_ext": BuildExtension},
+)


@@ -1,6 +1,5 @@
-import torch
 import numpy as np
-import time
+import torch

 from emd import earth_mover_distance

 # gt
@@ -13,10 +12,12 @@ print(p2)
 p1.requires_grad = True
 p2.requires_grad = True

-gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
-          (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
-          (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
-print('gt_dist: ', gt_dist)
+gt_dist = (
+    (((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
+    + (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
+    + (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
+)
+print("gt_dist: ", gt_dist)

 gt_dist.backward()
 print(p1.grad)
@@ -41,4 +42,3 @@ print(loss)
 loss.backward()
 print(p1.grad)
 print(p2.grad)
-


@ -1,17 +1,19 @@
import torch
import numpy as np
import warnings import warnings
import numpy as np
import torch
from numpy.linalg import norm
from scipy.stats import entropy from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
from metrics.ChamferDistancePytorch.fscore import fscore
from tqdm import tqdm from tqdm import tqdm
from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
from metrics.ChamferDistancePytorch.fscore import fscore
from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
cham3D = chamfer_3DDist() cham3D = chamfer_3DDist()
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet # Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b): def distChamfer(a, b):
x, y = a, b x, y = a, b
@ -22,11 +24,11 @@ def distChamfer(a, b):
diag_ind = torch.arange(0, num_points).to(a).long() diag_ind = torch.arange(0, num_points).to(a).long()
rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
P = (rx.transpose(2, 1) + ry - 2 * zz) P = rx.transpose(2, 1) + ry - 2 * zz
return P.min(1)[0], P.min(2)[0] return P.min(1)[0], P.min(2)[0]
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True): def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
N_sample = sample_pcs.shape[0] N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0] N_ref = ref_pcs.shape[0]
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
@ -56,13 +58,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
cd = torch.cat(cd_lst) cd = torch.cat(cd_lst)
emd = torch.cat(emd_lst) emd = torch.cat(emd_lst)
fs_lst = torch.cat(fs_lst).mean() fs_lst = torch.cat(fs_lst).mean()
results = { results = {"MMD-CD": cd, "MMD-EMD": emd, "fscore": fs_lst}
'MMD-CD': cd,
'MMD-EMD': emd,
'fscore': fs_lst
}
return results return results
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True): def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
N_sample = sample_pcs.shape[0] N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0] N_ref = ref_pcs.shape[0]
@ -107,7 +106,7 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0) M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
if sqrt: if sqrt:
M = M.abs().sqrt() M = M.abs().sqrt()
INFINITY = float('inf') INFINITY = float("inf")
val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False) val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
count = torch.zeros(n0 + n1).to(Mxx) count = torch.zeros(n0 + n1).to(Mxx)
@ -116,19 +115,21 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
s = { s = {
'tp': (pred * label).sum(), "tp": (pred * label).sum(),
'fp': (pred * (1 - label)).sum(), "fp": (pred * (1 - label)).sum(),
'fn': ((1 - pred) * label).sum(), "fn": ((1 - pred) * label).sum(),
'tn': ((1 - pred) * (1 - label)).sum(), "tn": ((1 - pred) * (1 - label)).sum(),
} }
s.update({ s.update(
'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), {
'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), "precision": s["tp"] / (s["tp"] + s["fp"] + 1e-10),
'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), "recall": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), "acc_t": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
'acc': torch.eq(label, pred).float().mean(), "acc_f": s["tn"] / (s["tn"] + s["fp"] + 1e-10),
}) "acc": torch.eq(label, pred).float().mean(),
}
)
return s return s
@ -141,9 +142,9 @@ def lgan_mmd_cov(all_dist):
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
cov = torch.tensor(cov).to(all_dist) cov = torch.tensor(cov).to(all_dist)
return { return {
'lgan_mmd': mmd, "lgan_mmd": mmd,
'lgan_cov': cov, "lgan_cov": cov,
'lgan_mmd_smp': mmd_smp, "lgan_mmd_smp": mmd_smp,
} }
@ -153,27 +154,19 @@ def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size) M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
res_cd = lgan_mmd_cov(M_rs_cd.t()) res_cd = lgan_mmd_cov(M_rs_cd.t())
results.update({ results.update({"%s-CD" % k: v for k, v in res_cd.items()})
"%s-CD" % k: v for k, v in res_cd.items()
})
res_emd = lgan_mmd_cov(M_rs_emd.t()) res_emd = lgan_mmd_cov(M_rs_emd.t())
results.update({ results.update({"%s-EMD" % k: v for k, v in res_emd.items()})
"%s-EMD" % k: v for k, v in res_emd.items()
})
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size) M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size) M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
# 1-NN results # 1-NN results
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
results.update({ results.update({"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if "acc" in k})
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
})
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False) one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
results.update({ results.update({"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if "acc" in k})
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
})
return results return results
@ -227,11 +220,11 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
bound = 0.5 + epsilon bound = 0.5 + epsilon
if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
if verbose: if verbose:
warnings.warn('Point-clouds are not in unit cube.') warnings.warn("Point-clouds are not in unit cube.")
if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
if verbose: if verbose:
warnings.warn('Point-clouds are not in unit sphere.') warnings.warn("Point-clouds are not in unit sphere.")
grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
grid_coordinates = grid_coordinates.reshape(-1, 3) grid_coordinates = grid_coordinates.reshape(-1, 3)
@ -260,9 +253,9 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
def jensen_shannon_divergence(P, Q): def jensen_shannon_divergence(P, Q):
if np.any(P < 0) or np.any(Q < 0): if np.any(P < 0) or np.any(Q < 0):
raise ValueError('Negative values.') raise ValueError("Negative values.")
if len(P) != len(Q): if len(P) != len(Q):
raise ValueError('Non equal size.') raise ValueError("Non equal size.")
P_ = P / np.sum(P) # Ensure probabilities. P_ = P / np.sum(P) # Ensure probabilities.
Q_ = Q / np.sum(Q) Q_ = Q / np.sum(Q)
@ -275,7 +268,7 @@ def jensen_shannon_divergence(P, Q):
res2 = _jsdiv(P_, Q_) res2 = _jsdiv(P_, Q_)
if not np.allclose(res, res2, atol=10e-5, rtol=0): if not np.allclose(res, res2, atol=10e-5, rtol=0):
warnings.warn('Numerical values of two JSD methods don\'t agree.') warnings.warn("Numerical values of two JSD methods don't agree.")
return res return res
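A toy numeric check of the quantity being computed (pure NumPy, using the definition JSD(P, Q) = KL(P||M)/2 + KL(Q||M)/2 with M = (P + Q)/2):

import numpy as np

P = np.array([0.5, 0.5])
Q = np.array([0.9, 0.1])
M = 0.5 * (P + Q)
kl = lambda a, b: np.sum(a * np.log(a / b))
jsd = 0.5 * kl(P, M) + 0.5 * kl(Q, M)  # ~0.102 nats; 0 iff P == Q, at most ln 2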
@ -312,11 +305,9 @@ if __name__ == "__main__":
r_dist = min_r.mean().cpu().detach().item() r_dist = min_r.mean().cpu().detach().item()
print(l_dist, r_dist) print(l_dist, r_dist)
emd_batch = EMD(x.cuda(), y.cuda(), False) emd_batch = EMD(x.cuda(), y.cuda(), False)
print(emd_batch.shape) print(emd_batch.shape)
print(emd_batch.mean().detach().item()) print(emd_batch.mean().detach().item())
jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy()) jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy())
print(jsd) print(jsd)
View file
@ -1,13 +1,14 @@
import functools import functools
import torch.nn as nn
import torch
import numpy as np import numpy as np
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish import torch
import torch.nn as nn
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
def _linear_gn_relu(in_channels, out_channels): def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, def create_pointnet_components(
width_multiplier=1, voxel_resolution_multiplier=1): blocks,
in_channels,
embed_dim,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0 layers, concat_channels = [], 0
@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
if voxel_resolution is None: if voxel_resolution is None:
block = SharedMLP block = SharedMLP
else: else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, block = functools.partial(
with_se=with_se, normalize=normalize, eps=eps) PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
with_se=with_se,
normalize=normalize,
eps=eps,
)
if c == 0: if c == 0:
layers.append(block(in_channels, out_channels)) layers.append(block(in_channels, out_channels))
else: else:
layers.append(block(in_channels+embed_dim, out_channels)) layers.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels in_channels = out_channels
concat_channels += out_channels concat_channels += out_channels
c += 1 c += 1
return layers, in_channels, concat_channels return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False, def create_pointnet2_sa_components(
dropout=0.1, with_se=False, normalize=True, eps=0, sa_blocks,
width_multiplier=1, voxel_resolution_multiplier=1): extra_feature_channels,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3 in_channels = extra_feature_channels + 3
@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
out_channels, num_blocks, voxel_resolution = conv_configs out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels) out_channels = int(r * out_channels)
for p in range(num_blocks): for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0 attention = (c + 1) % 2 == 0 and c > 0 and use_att and p == 0
if voxel_resolution is None: if voxel_resolution is None:
block = SharedMLP block = SharedMLP
else: else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, block = functools.partial(
dropout=dropout, PVConv,
with_se=with_se and not attention, with_se_relu=True, kernel_size=3,
normalize=normalize, eps=eps) resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se and not attention,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
if c == 0: if c == 0:
sa_blocks.append(block(in_channels, out_channels)) sa_blocks.append(block(in_channels, out_channels))
elif k ==0: elif k == 0:
sa_blocks.append(block(in_channels+embed_dim, out_channels)) sa_blocks.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels in_channels = out_channels
k += 1 k += 1
extra_feature_channels = in_channels extra_feature_channels = in_channels
@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
if num_centers is None: if num_centers is None:
block = PointNetAModule block = PointNetAModule
else: else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, block = functools.partial(
num_neighbors=num_neighbors) PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels, )
include_coordinates=True)) sa_blocks.append(
block(
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
out_channels=out_channels,
include_coordinates=True,
)
)
c += 1 c += 1
in_channels = extra_feature_channels = sa_blocks[-1].out_channels in_channels = extra_feature_channels = sa_blocks[-1].out_channels
if len(sa_blocks) == 1: if len(sa_blocks) == 1:
@ -127,10 +165,20 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False, def create_pointnet2_fp_modules(
dropout=0.1, fp_blocks,
with_se=False, normalize=True, eps=0, in_channels,
width_multiplier=1, voxel_resolution_multiplier=1): sa_in_channels,
sv_points,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = [] fp_layers = []
@ -139,7 +187,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
fp_blocks = [] fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs) out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append( fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels) PointNetFPModule(
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
)
) )
in_channels = out_channels[-1] in_channels = out_channels[-1]
@ -151,9 +201,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
if voxel_resolution is None: if voxel_resolution is None:
block = SharedMLP block = SharedMLP
else: else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, block = functools.partial(
dropout=dropout, PVConv,
with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps) kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se and not attention,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
fp_blocks.append(block(in_channels, out_channels)) fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels in_channels = out_channels
@ -168,9 +226,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
class PVCNN2Base(nn.Module): class PVCNN2Base(nn.Module):
def __init__(
def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1, self,
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1): num_classes,
sv_points,
embed_dim,
use_att,
dropout=0.1,
extra_feature_channels=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
super().__init__() super().__init__()
assert extra_feature_channels >= 0 assert extra_feature_channels >= 0
self.embed_dim = embed_dim self.embed_dim = embed_dim
@ -178,9 +244,14 @@ class PVCNN2Base(nn.Module):
self.in_channels = extra_feature_channels + 3 self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components( sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim, sa_blocks=self.sa_blocks,
use_att=use_att, dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
) )
self.sa_layers = nn.ModuleList(sa_layers) self.sa_layers = nn.ModuleList(sa_layers)
@ -189,16 +260,26 @@ class PVCNN2Base(nn.Module):
# only use extra features in the last fp module # only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules( fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points, fp_blocks=self.fp_blocks,
with_se=True, embed_dim=embed_dim, in_channels=channels_sa_features,
use_att=use_att, dropout=dropout, sa_in_channels=sa_in_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier sv_points=sv_points,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
) )
self.fp_layers = nn.ModuleList(fp_layers) self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes], in_channels=channels_fp_features,
classifier=True, dim=2, width_multiplier=width_multiplier) out_channels=[128, 0.5, num_classes],
classifier=True,
dim=2,
width_multiplier=width_multiplier,
)
self.classifier = nn.Sequential(*layers) self.classifier = nn.Sequential(*layers)
self.embedf = nn.Sequential( self.embedf = nn.Sequential(
@ -223,31 +304,30 @@ class PVCNN2Base(nn.Module):
return emb return emb
def forward(self, inputs, t): def forward(self, inputs, t):
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
# inputs : [B, in_channels + S, N] # inputs : [B, in_channels + S, N]
coords, features = inputs[:, :3, :].contiguous(), inputs coords, features = inputs[:, :3, :].contiguous(), inputs
coords_list, in_features_list = [], [] coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers): for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features) in_features_list.append(features)
coords_list.append(coords) coords_list.append(coords)
if i == 0: if i == 0:
features, coords, temb = sa_blocks ((features, coords, temb)) features, coords, temb = sa_blocks((features, coords, temb))
else: else:
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb)) features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
in_features_list[0] = inputs[:, 3:, :].contiguous() in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None: if self.global_att is not None:
features = self.global_att(features) features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers): for fp_idx, fp_blocks in enumerate(self.fp_layers):
jump_coords = coords_list[-1 - fp_idx] jump_coords = coords_list[-1 - fp_idx]
fump_feats = in_features_list[-1-fp_idx] fump_feats = in_features_list[-1 - fp_idx]
# if fp_idx == len(self.fp_layers) - 1: # if fp_idx == len(self.fp_layers) - 1:
# jump_coords = jump_coords[:,:,self.sv_points:] # jump_coords = jump_coords[:,:,self.sv_points:]
# fump_feats = fump_feats[:,:,self.sv_points:] # fump_feats = fump_feats[:,:,self.sv_points:]
features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features,temb],dim=1), fump_feats, temb)) features, coords, temb = fp_blocks(
(jump_coords, coords, torch.cat([features, temb], dim=1), fump_feats, temb)
)
return self.classifier(features) return self.classifier(features)
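The forward pass above leans on self.get_timestep_embedding, whose body falls outside these hunks; a sketch of the standard sinusoidal DDPM embedding it presumably implements (an assumption, not verified against the full file):

import math
import torch

def get_timestep_embedding_sketch(timesteps, embed_dim=64):
    # Sinusoidal embedding: geometric frequency ladder, then sin/cos features.
    half_dim = embed_dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1))
    args = timesteps.float()[:, None] * freqs[None, :]           # [B, half_dim]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # [B, embed_dim]

# forward() then lifts the embedding with self.embedf and broadcasts it over the N points:
# temb = self.embedf(emb)[:, :, None].expand(-1, -1, N)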
View file
@ -1,13 +1,14 @@
import functools import functools
import torch.nn as nn
import torch
import numpy as np import numpy as np
from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish import torch
import torch.nn as nn
from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish
def _linear_gn_relu(in_channels, out_channels): def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, def create_pointnet_components(
width_multiplier=1, voxel_resolution_multiplier=1): blocks,
in_channels,
embed_dim,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0 layers, concat_channels = [], 0
@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
if voxel_resolution is None: if voxel_resolution is None:
block = SharedMLP block = SharedMLP
else: else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, block = functools.partial(
with_se=with_se, normalize=normalize, eps=eps) PVConv,
kernel_size=3,
resolution=int(vr * voxel_resolution),
attention=attention,
with_se=with_se,
normalize=normalize,
eps=eps,
)
if c == 0: if c == 0:
layers.append(block(in_channels, out_channels)) layers.append(block(in_channels, out_channels))
else: else:
layers.append(block(in_channels+embed_dim, out_channels)) layers.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels in_channels = out_channels
concat_channels += out_channels concat_channels += out_channels
c += 1 c += 1
return layers, in_channels, concat_channels return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False, def create_pointnet2_sa_components(
dropout=0.1, with_se=False, normalize=True, eps=0, sa_blocks,
width_multiplier=1, voxel_resolution_multiplier=1): extra_feature_channels,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3 in_channels = extra_feature_channels + 3
@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
out_channels, num_blocks, voxel_resolution = conv_configs out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels) out_channels = int(r * out_channels)
for p in range(num_blocks): for p in range(num_blocks):
attention = (c+1) % 2 == 0 and use_att and p == 0 attention = (c + 1) % 2 == 0 and use_att and p == 0
if voxel_resolution is None: if voxel_resolution is None:
block = SharedMLP block = SharedMLP
else: else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, block = functools.partial(
dropout=dropout, PVConv,
with_se=with_se, with_se_relu=True, kernel_size=3,
normalize=normalize, eps=eps) resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
if c == 0: if c == 0:
sa_blocks.append(block(in_channels, out_channels)) sa_blocks.append(block(in_channels, out_channels))
elif k ==0: elif k == 0:
sa_blocks.append(block(in_channels+embed_dim, out_channels)) sa_blocks.append(block(in_channels + embed_dim, out_channels))
in_channels = out_channels in_channels = out_channels
k += 1 k += 1
extra_feature_channels = in_channels extra_feature_channels = in_channels
@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
if num_centers is None: if num_centers is None:
block = PointNetAModule block = PointNetAModule
else: else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, block = functools.partial(
num_neighbors=num_neighbors) PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels, )
include_coordinates=True)) sa_blocks.append(
block(
in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
out_channels=out_channels,
include_coordinates=True,
)
)
c += 1 c += 1
in_channels = extra_feature_channels = sa_blocks[-1].out_channels in_channels = extra_feature_channels = sa_blocks[-1].out_channels
if len(sa_blocks) == 1: if len(sa_blocks) == 1:
@ -127,10 +165,19 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False, def create_pointnet2_fp_modules(
dropout=0.1, fp_blocks,
with_se=False, normalize=True, eps=0, in_channels,
width_multiplier=1, voxel_resolution_multiplier=1): sa_in_channels,
embed_dim=64,
use_att=False,
dropout=0.1,
with_se=False,
normalize=True,
eps=0,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
r, vr = width_multiplier, voxel_resolution_multiplier r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = [] fp_layers = []
@ -139,7 +186,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
fp_blocks = [] fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs) out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append( fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels) PointNetFPModule(
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
)
) )
in_channels = out_channels[-1] in_channels = out_channels[-1]
@ -147,14 +196,21 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
out_channels, num_blocks, voxel_resolution = conv_configs out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels) out_channels = int(r * out_channels)
for p in range(num_blocks): for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 attention = (c + 1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
if voxel_resolution is None: if voxel_resolution is None:
block = SharedMLP block = SharedMLP
else: else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, block = functools.partial(
dropout=dropout, PVConv,
with_se=with_se, with_se_relu=True, kernel_size=3,
normalize=normalize, eps=eps) resolution=int(vr * voxel_resolution),
attention=attention,
dropout=dropout,
with_se=with_se,
with_se_relu=True,
normalize=normalize,
eps=eps,
)
fp_blocks.append(block(in_channels, out_channels)) fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels in_channels = out_channels
@ -168,20 +224,31 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
return fp_layers, in_channels return fp_layers, in_channels
class PVCNN2Base(nn.Module): class PVCNN2Base(nn.Module):
def __init__(
def __init__(self, num_classes, embed_dim, use_att, dropout=0.1, self,
extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1): num_classes,
embed_dim,
use_att,
dropout=0.1,
extra_feature_channels=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
):
super().__init__() super().__init__()
assert extra_feature_channels >= 0 assert extra_feature_channels >= 0
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.in_channels = extra_feature_channels + 3 self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components( sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim, sa_blocks=self.sa_blocks,
use_att=use_att, dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
) )
self.sa_layers = nn.ModuleList(sa_layers) self.sa_layers = nn.ModuleList(sa_layers)
@ -190,15 +257,25 @@ class PVCNN2Base(nn.Module):
# only use extra features in the last fp module # only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules( fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim, fp_blocks=self.fp_blocks,
use_att=use_att, dropout=dropout, in_channels=channels_sa_features,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier sa_in_channels=sa_in_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
) )
self.fp_layers = nn.ModuleList(fp_layers) self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(
layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes], # was 0.5 in_channels=channels_fp_features,
classifier=True, dim=2, width_multiplier=width_multiplier) out_channels=[128, dropout, num_classes], # was 0.5
classifier=True,
dim=2,
width_multiplier=width_multiplier,
)
self.classifier = nn.Sequential(*layers) self.classifier = nn.Sequential(*layers)
self.embedf = nn.Sequential( self.embedf = nn.Sequential(
@ -223,25 +300,30 @@ class PVCNN2Base(nn.Module):
return emb return emb
def forward(self, inputs, t): def forward(self, inputs, t):
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
# inputs : [B, in_channels + S, N] # inputs : [B, in_channels + S, N]
coords, features = inputs[:, :3, :].contiguous(), inputs coords, features = inputs[:, :3, :].contiguous(), inputs
coords_list, in_features_list = [], [] coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers): for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features) in_features_list.append(features)
coords_list.append(coords) coords_list.append(coords)
if i == 0: if i == 0:
features, coords, temb = sa_blocks ((features, coords, temb)) features, coords, temb = sa_blocks((features, coords, temb))
else: else:
features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb)) features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
in_features_list[0] = inputs[:, 3:, :].contiguous() in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None: if self.global_att is not None:
features = self.global_att(features) features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers): for fp_idx, fp_blocks in enumerate(self.fp_layers):
features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb)) features, coords, temb = fp_blocks(
(
coords_list[-1 - fp_idx],
coords,
torch.cat([features, temb], dim=1),
in_features_list[-1 - fp_idx],
temb,
)
)
return self.classifier(features) return self.classifier(features)
View file
@ -1,8 +0,0 @@
from modules.ball_query import BallQuery
from modules.frustum import FrustumPointNetLoss
from modules.loss import KLLoss
from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
from modules.pvconv import PVConv, Attention, Swish, PVConvReLU
from modules.se import SE3d
from modules.shared_mlp import SharedMLP
from modules.voxelization import Voxelization
View file
@ -3,7 +3,7 @@ import torch.nn as nn
import modules.functional as F import modules.functional as F
__all__ = ['BallQuery'] __all__ = ["BallQuery"]
class BallQuery(nn.Module): class BallQuery(nn.Module):
@ -21,7 +21,7 @@ class BallQuery(nn.Module):
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1) neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
if points_features is None: if points_features is None:
assert self.include_coordinates, 'No Features For Grouping' assert self.include_coordinates, "No Features For Grouping"
neighbor_features = neighbor_coordinates neighbor_features = neighbor_coordinates
else: else:
neighbor_features = F.grouping(points_features, neighbor_indices) neighbor_features = F.grouping(points_features, neighbor_indices)
@ -30,5 +30,6 @@ class BallQuery(nn.Module):
return neighbor_features, F.grouping(temb, neighbor_indices) return neighbor_features, F.grouping(temb, neighbor_indices)
def extra_repr(self): def extra_repr(self):
return 'radius={}, num_neighbors={}{}'.format( return "radius={}, num_neighbors={}{}".format(
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '') self.radius, self.num_neighbors, ", include coordinates" if self.include_coordinates else ""
)
View file
@ -5,12 +5,20 @@ import torch.nn.functional as F
import modules.functional as PF import modules.functional as PF
__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d'] __all__ = ["FrustumPointNetLoss", "get_box_corners_3d"]
class FrustumPointNetLoss(nn.Module): class FrustumPointNetLoss(nn.Module):
def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0, def __init__(
corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0): self,
num_heading_angle_bins,
num_size_templates,
size_templates,
box_loss_weight=1.0,
corners_loss_weight=10.0,
heading_residual_loss_weight=20.0,
size_residual_loss_weight=20.0,
):
super().__init__() super().__init__()
self.box_loss_weight = box_loss_weight self.box_loss_weight = box_loss_weight
self.corners_loss_weight = corners_loss_weight self.corners_loss_weight = corners_loss_weight
@ -19,28 +27,28 @@ class FrustumPointNetLoss(nn.Module):
self.num_heading_angle_bins = num_heading_angle_bins self.num_heading_angle_bins = num_heading_angle_bins
self.num_size_templates = num_size_templates self.num_size_templates = num_size_templates
self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3)) self.register_buffer("size_templates", size_templates.view(self.num_size_templates, 3))
self.register_buffer( self.register_buffer(
'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins) "heading_angle_bin_centers", torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
) )
def forward(self, inputs, targets): def forward(self, inputs, targets):
mask_logits = inputs['mask_logits'] # (B, 2, N) mask_logits = inputs["mask_logits"] # (B, 2, N)
center_reg = inputs['center_reg'] # (B, 3) center_reg = inputs["center_reg"] # (B, 3)
center = inputs['center'] # (B, 3) center = inputs["center"] # (B, 3)
heading_scores = inputs['heading_scores'] # (B, NH) heading_scores = inputs["heading_scores"] # (B, NH)
heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH) heading_residuals_normalized = inputs["heading_residuals_normalized"] # (B, NH)
heading_residuals = inputs['heading_residuals'] # (B, NH) heading_residuals = inputs["heading_residuals"] # (B, NH)
size_scores = inputs['size_scores'] # (B, NS) size_scores = inputs["size_scores"] # (B, NS)
size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3) size_residuals_normalized = inputs["size_residuals_normalized"] # (B, NS, 3)
size_residuals = inputs['size_residuals'] # (B, NS, 3) size_residuals = inputs["size_residuals"] # (B, NS, 3)
mask_logits_target = targets['mask_logits'] # (B, N) mask_logits_target = targets["mask_logits"] # (B, N)
center_target = targets['center'] # (B, 3) center_target = targets["center"] # (B, 3)
heading_bin_id_target = targets['heading_bin_id'] # (B, ) heading_bin_id_target = targets["heading_bin_id"] # (B, )
heading_residual_target = targets['heading_residual'] # (B, ) heading_residual_target = targets["heading_residual"] # (B, )
size_template_id_target = targets['size_template_id'] # (B, ) size_template_id_target = targets["size_template_id"] # (B, )
size_residual_target = targets['size_residual'] # (B, 3) size_residual_target = targets["size_residual"] # (B, 3)
batch_size = center.size(0) batch_size = center.size(0)
batch_id = torch.arange(batch_size, device=center.device) batch_id = torch.arange(batch_size, device=center.device)
@ -65,25 +73,32 @@ class FrustumPointNetLoss(nn.Module):
) )
# Bounding box losses # Bounding box losses
heading = (heading_residuals[batch_id, heading_bin_id_target] heading = (
+ self.heading_angle_bin_centers[heading_bin_id_target]) # (B, ) heading_residuals[batch_id, heading_bin_id_target] + self.heading_angle_bin_centers[heading_bin_id_target]
) # (B, )
# Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets) # Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets)
size = (size_residuals[batch_id, size_template_id_target] size = (
+ self.size_templates[size_template_id_target]) # (B, 3) size_residuals[batch_id, size_template_id_target] + self.size_templates[size_template_id_target]
) # (B, 3)
corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8) corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, ) heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3) size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target, corners_target, corners_target_flip = get_box_corners_3d(
sizes=size_target, with_flip=True) # (B, 3, 8) centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
corners_loss = PF.huber_loss(torch.min( ) # (B, 3, 8)
torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1) corners_loss = PF.huber_loss(
), delta=1.0) torch.min(torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)),
delta=1.0,
)
# Summing up # Summing up
loss = mask_loss + self.box_loss_weight * ( loss = mask_loss + self.box_loss_weight * (
center_loss + center_reg_loss + heading_loss + size_loss center_loss
+ self.heading_residual_loss_weight * heading_residual_normalized_loss + center_reg_loss
+ self.size_residual_loss_weight * size_residual_normalized_loss + heading_loss
+ self.corners_loss_weight * corners_loss + size_loss
+ self.heading_residual_loss_weight * heading_residual_normalized_loss
+ self.size_residual_loss_weight * size_residual_normalized_loss
+ self.corners_loss_weight * corners_loss
) )
return loss return loss
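In equation form, the weighted sum assembled above is

$$
\mathcal{L} = \mathcal{L}_{\text{mask}} + \lambda_{\text{box}}\bigl(\mathcal{L}_{\text{center}} + \mathcal{L}_{\text{center-reg}} + \mathcal{L}_{\text{heading}} + \mathcal{L}_{\text{size}} + \lambda_{\text{hr}}\,\mathcal{L}_{\text{heading-res}} + \lambda_{\text{sr}}\,\mathcal{L}_{\text{size-res}} + \lambda_{\text{c}}\,\mathcal{L}_{\text{corners}}\bigr),
$$

with constructor defaults $\lambda_{\text{box}} = 1$, $\lambda_{\text{hr}} = \lambda_{\text{sr}} = 20$, and $\lambda_{\text{c}} = 10$.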
@ -105,9 +120,9 @@ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
l = sizes[:, 0] # (N,) l = sizes[:, 0] # (N,)
w = sizes[:, 1] # (N,) w = sizes[:, 1] # (N,)
h = sizes[:, 2] # (N,) h = sizes[:, 2] # (N,)
x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8) x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8) y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8) z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
c = torch.cos(headings) # (N,) c = torch.cos(headings) # (N,)
s = torch.sin(headings) # (N,) s = torch.sin(headings) # (N,)
View file
@ -1,7 +0,0 @@
from modules.functional.ball_query import ball_query
from modules.functional.devoxelization import trilinear_devoxelize
from modules.functional.grouping import grouping
from modules.functional.interpolatation import nearest_neighbor_interpolate
from modules.functional.loss import kl_loss, huber_loss
from modules.functional.sampling import gather, furthest_point_sample, logits_mask
from modules.functional.voxelization import avg_voxelize
View file
@ -3,24 +3,28 @@ import os
from torch.utils.cpp_extension import load from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__)) _src_path = os.path.dirname(os.path.abspath(__file__))
_backend = load(name='_pvcnn_backend', _backend = load(
extra_cflags=['-O3', '-std=c++17'], name="_pvcnn_backend",
extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'], extra_cflags=["-O3", "-std=c++17"],
sources=[os.path.join(_src_path,'src', f) for f in [ extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
'ball_query/ball_query.cpp', sources=[
'ball_query/ball_query.cu', os.path.join(_src_path, "src", f)
'grouping/grouping.cpp', for f in [
'grouping/grouping.cu', "ball_query/ball_query.cpp",
'interpolate/neighbor_interpolate.cpp', "ball_query/ball_query.cu",
'interpolate/neighbor_interpolate.cu', "grouping/grouping.cpp",
'interpolate/trilinear_devox.cpp', "grouping/grouping.cu",
'interpolate/trilinear_devox.cu', "interpolate/neighbor_interpolate.cpp",
'sampling/sampling.cpp', "interpolate/neighbor_interpolate.cu",
'sampling/sampling.cu', "interpolate/trilinear_devox.cpp",
'voxelization/vox.cpp', "interpolate/trilinear_devox.cu",
'voxelization/vox.cu', "sampling/sampling.cpp",
'bindings.cpp', "sampling/sampling.cu",
]] "voxelization/vox.cpp",
) "voxelization/vox.cu",
"bindings.cpp",
]
],
)
__all__ = ['_backend'] __all__ = ["_backend"]
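The --compiler-bindir flag above pins a cluster-specific GCC (/softs/gcc/11.2.0), which will break the JIT build elsewhere. A hedged sketch of one way to keep it portable; the environment-variable name is an invention for illustration:

import os

# Hypothetical: let an env var override the CUDA host compiler and fall back
# to the system default when unset, instead of hard-coding the path.
extra_cuda_cflags = []
bindir = os.environ.get("PVCNN_GCC_BINDIR")  # invented name
if bindir:
    extra_cuda_cflags.append(f"--compiler-bindir={bindir}")
# then pass extra_cuda_cflags=extra_cuda_cflags to load(...) as above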
View file
@ -1,19 +1,17 @@
from torch.autograd import Function
from modules.functional.backend import _backend from modules.functional.backend import _backend
__all__ = ['ball_query'] __all__ = ["ball_query"]
def ball_query(centers_coords, points_coords, radius, num_neighbors): def ball_query(centers_coords, points_coords, radius, num_neighbors):
""" """
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M] :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param points_coords: coordinates of points, FloatTensor[B, 3, N] :param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param radius: float, radius of ball query :param radius: float, radius of ball query
:param num_neighbors: int, maximum number of neighbors :param num_neighbors: int, maximum number of neighbors
:return: :return:
neighbor_indices: indices of neighbors, IntTensor[B, M, U] neighbor_indices: indices of neighbors, IntTensor[B, M, U]
""" """
centers_coords = centers_coords.contiguous() centers_coords = centers_coords.contiguous()
points_coords = points_coords.contiguous() points_coords = points_coords.contiguous()
return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors) return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
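Hypothetical usage, following the docstring's shape conventions (B batches, M centers, N candidate points, up to U = num_neighbors neighbors per center):

import torch

centers = torch.rand(2, 3, 128).cuda()  # FloatTensor[B, 3, M]
points = torch.rand(2, 3, 2048).cuda()  # FloatTensor[B, 3, N]
idx = ball_query(centers, points, radius=0.1, num_neighbors=32)
# idx: IntTensor[B, M, U], indices into the N candidate points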
View file
@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend from modules.functional.backend import _backend
__all__ = ['trilinear_devoxelize'] __all__ = ["trilinear_devoxelize"]
class TrilinearDevoxelization(Function): class TrilinearDevoxelization(Function):
@ -29,7 +29,7 @@ class TrilinearDevoxelization(Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
""" """
:param ctx: :param ctx:
:param grad_output: gradient of outputs, FloatTensor[B, C, N] :param grad_output: gradient of outputs, FloatTensor[B, C, N]
:return: :return:
gradient of inputs, FloatTensor[B, C, R, R, R] gradient of inputs, FloatTensor[B, C, R, R, R]
View file
@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend from modules.functional.backend import _backend
__all__ = ['grouping'] __all__ = ["grouping"]
class Grouping(Function): class Grouping(Function):
@ -23,7 +23,7 @@ class Grouping(Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indices, = ctx.saved_tensors (indices,) = ctx.saved_tensors
grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points) grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None return grad_features, None
View file
@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend from modules.functional.backend import _backend
__all__ = ['nearest_neighbor_interpolate'] __all__ = ["nearest_neighbor_interpolate"]
class NeighborInterpolation(Function): class NeighborInterpolation(Function):
View file
@ -1,7 +1,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
__all__ = ['kl_loss', 'huber_loss'] __all__ = ["kl_loss", "huber_loss"]
def kl_loss(x, y): def kl_loss(x, y):
@ -13,5 +13,5 @@ def kl_loss(x, y):
def huber_loss(error, delta): def huber_loss(error, delta):
abs_error = torch.abs(error) abs_error = torch.abs(error)
quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta)) quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic) losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
return torch.mean(losses) return torch.mean(losses)
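Written out, with $q = \min(|e|, \delta)$ the clipped error magnitude, the loss above is

$$
\ell_{\delta}(e) = \tfrac{1}{2}\,q^{2} + \delta\,(|e| - q),
$$

quadratic for $|e| \le \delta$ and linear beyond it; the function returns the mean over all elements.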
View file
@ -4,7 +4,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend from modules.functional.backend import _backend
__all__ = ['gather', 'furthest_point_sample', 'logits_mask'] __all__ = ["gather", "furthest_point_sample", "logits_mask"]
class Gather(Function): class Gather(Function):
@ -26,7 +26,7 @@ class Gather(Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indices, = ctx.saved_tensors (indices,) = ctx.saved_tensors
grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points) grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None return grad_features, None
@ -60,11 +60,12 @@ def logits_mask(coords, logits, num_points_per_object):
mask: mask to select points, BoolTensor[B, N] mask: mask to select points, BoolTensor[B, N]
""" """
batch_size, _, num_points = coords.shape batch_size, _, num_points = coords.shape
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N] mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1] num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N] masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, masked_coords_mean = (
torch.ones_like(num_candidates)).float() # [B, C] torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, torch.ones_like(num_candidates)).float()
) # [B, C]
selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32) selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
for i in range(batch_size): for i in range(batch_size):
current_mask = mask[i] # [N] current_mask = mask[i] # [N]
@ -74,10 +75,14 @@ def logits_mask(coords, logits, num_points_per_object):
choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False) choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
selected_indices[i] = current_candidates[choices] selected_indices[i] = current_candidates[choices]
elif current_num_candidates > 0: elif current_num_candidates > 0:
choices = np.concatenate([ choices = np.concatenate(
np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates), [
np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False) np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
]) np.random.choice(
current_num_candidates, num_points_per_object % current_num_candidates, replace=False
),
]
)
np.random.shuffle(choices) np.random.shuffle(choices)
selected_indices[i] = current_candidates[choices] selected_indices[i] = current_candidates[choices]
selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices) selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
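The elif branch above pads an undersized candidate set up to num_points_per_object by repeating indices before a final shuffle; a toy trace of the same trick (values illustrative):

import numpy as np

num_points_per_object = 5
current_num_candidates = 2  # fewer candidates than requested points
choices = np.concatenate([
    np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
    np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False),
])
np.random.shuffle(choices)  # e.g. [0, 1, 0, 1, 0]: every candidate kept, duplicates fill the rest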
View file
@ -2,7 +2,7 @@ from torch.autograd import Function
from modules.functional.backend import _backend from modules.functional.backend import _backend
__all__ = ['avg_voxelize'] __all__ = ["avg_voxelize"]
class AvgVoxelization(Function): class AvgVoxelization(Function):
View file
@ -2,7 +2,7 @@ import torch.nn as nn
import modules.functional as F import modules.functional as F
__all__ = ['KLLoss'] __all__ = ["KLLoss"]
class KLLoss(nn.Module): class KLLoss(nn.Module):
View file
@ -5,7 +5,7 @@ import modules.functional as F
from modules.ball_query import BallQuery from modules.ball_query import BallQuery
from modules.shared_mlp import SharedMLP from modules.shared_mlp import SharedMLP
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule'] __all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]
class PointNetAModule(nn.Module): class PointNetAModule(nn.Module):
@ -20,8 +20,9 @@ class PointNetAModule(nn.Module):
total_out_channels = 0 total_out_channels = 0
for _out_channels in out_channels: for _out_channels in out_channels:
mlps.append( mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), SharedMLP(
out_channels=_out_channels, dim=1) in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=1
)
) )
total_out_channels += _out_channels[-1] total_out_channels += _out_channels[-1]
@ -43,7 +44,7 @@ class PointNetAModule(nn.Module):
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
def extra_repr(self): def extra_repr(self):
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}' return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"
class PointNetSAModule(nn.Module): class PointNetSAModule(nn.Module):
@ -67,8 +68,9 @@ class PointNetSAModule(nn.Module):
BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates) BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
) )
mlps.append( mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), SharedMLP(
out_channels=_out_channels, dim=2) in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=2
)
) )
total_out_channels += _out_channels[-1] total_out_channels += _out_channels[-1]
@ -90,7 +92,7 @@ class PointNetSAModule(nn.Module):
return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
def extra_repr(self): def extra_repr(self):
return f'num_centers={self.num_centers}, out_channels={self.out_channels}' return f"num_centers={self.num_centers}, out_channels={self.out_channels}"
class PointNetFPModule(nn.Module): class PointNetFPModule(nn.Module):
@ -107,7 +109,5 @@ class PointNetFPModule(nn.Module):
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features) interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb) interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
if points_features is not None: if points_features is not None:
interpolated_features = torch.cat( interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
[interpolated_features, points_features], dim=1
)
return self.mlp(interpolated_features), points_coords, interpolated_temb return self.mlp(interpolated_features), points_coords, interpolated_temb
View file
@ -1,16 +1,17 @@
import torch.nn as nn
import torch import torch
import modules.functional as F import torch.nn as nn
from modules.voxelization import Voxelization
from modules.shared_mlp import SharedMLP
from modules.se import SE3d
__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU'] import modules.functional as F
from modules.se import SE3d
from modules.shared_mlp import SharedMLP
from modules.voxelization import Voxelization
__all__ = ["PVConv", "Attention", "Swish", "PVConvReLU"]
class Swish(nn.Module): class Swish(nn.Module):
def forward(self,x): def forward(self, x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
class Attention(nn.Module): class Attention(nn.Module):
@ -35,23 +36,19 @@ class Attention(nn.Module):
self.sm = nn.Softmax(-1) self.sm = nn.Softmax(-1)
def forward(self, x): def forward(self, x):
B, C = x.shape[:2] B, C = x.shape[:2]
h = x h = x
q = self.q(h).reshape(B, C, -1)
k = self.k(h).reshape(B, C, -1)
v = self.v(h).reshape(B, C, -1)
qk = torch.matmul(q.permute(0, 2, 1), k) # * (int(C) ** (-0.5))
q = self.q(h).reshape(B,C,-1)
k = self.k(h).reshape(B,C,-1)
v = self.v(h).reshape(B,C,-1)
qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
w = self.sm(qk) w = self.sm(qk)
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:]) h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, *x.shape[2:])
h = self.out(h) h = self.out(h)
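Note that the usual $1/\sqrt{C}$ attention scaling survives only as the commented-out factor above. A self-contained sketch of the same computation with the scaling applied (shapes follow the reshapes above; the projections are stand-ins for self.q/self.k/self.v):

import torch

B, C, R = 2, 8, 4
x = torch.rand(B, C, R, R, R)   # voxel features, as in PVConv
q = x.reshape(B, C, -1)
k = x.reshape(B, C, -1)
v = x.reshape(B, C, -1)
qk = torch.matmul(q.permute(0, 2, 1), k) * C ** -0.5  # scaled dot product
w = torch.softmax(qk, dim=-1)
h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, R, R, R)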
@ -61,9 +58,21 @@ class Attention(nn.Module):
return x return x
class PVConv(nn.Module): class PVConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, def __init__(
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0): self,
in_channels,
out_channels,
kernel_size,
resolution,
attention=False,
dropout=0.1,
with_se=False,
with_se_relu=False,
normalize=True,
eps=0,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -74,13 +83,13 @@ class PVConv(nn.Module):
voxel_layers = [ voxel_layers = [
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.GroupNorm(num_groups=8, num_channels=out_channels), nn.GroupNorm(num_groups=8, num_channels=out_channels),
Swish() Swish(),
] ]
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else [] voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
voxel_layers += [ voxel_layers += [
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.GroupNorm(num_groups=8, num_channels=out_channels), nn.GroupNorm(num_groups=8, num_channels=out_channels),
Attention(out_channels, 8) if attention else Swish() Attention(out_channels, 8) if attention else Swish(),
] ]
if with_se: if with_se:
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu)) voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
@ -96,10 +105,21 @@ class PVConv(nn.Module):
return fused_features, coords, temb return fused_features, coords, temb
class PVConvReLU(nn.Module): class PVConvReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2, def __init__(
dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0): self,
in_channels,
out_channels,
kernel_size,
resolution,
attention=False,
leak=0.2,
dropout=0.1,
with_se=False,
with_se_relu=False,
normalize=True,
eps=0,
):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -110,13 +130,13 @@ class PVConvReLU(nn.Module):
voxel_layers = [ voxel_layers = [
nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.BatchNorm3d(out_channels), nn.BatchNorm3d(out_channels),
nn.LeakyReLU(leak, True) nn.LeakyReLU(leak, True),
] ]
voxel_layers += [nn.Dropout(dropout)] if dropout is not None else [] voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
voxel_layers += [ voxel_layers += [
nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
nn.BatchNorm3d(out_channels), nn.BatchNorm3d(out_channels),
Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True) Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True),
] ]
if with_se: if with_se:
voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu)) voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
View file
@ -1,18 +1,22 @@
import torch.nn as nn
import torch import torch
__all__ = ['SE3d'] import torch.nn as nn
__all__ = ["SE3d"]
class Swish(nn.Module): class Swish(nn.Module):
def forward(self,x): def forward(self, x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
class SE3d(nn.Module): class SE3d(nn.Module):
def __init__(self, channel, reduction=8, use_relu=False): def __init__(self, channel, reduction=8, use_relu=False):
super().__init__() super().__init__()
self.fc = nn.Sequential( self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False), nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(True) if use_relu else Swish() , nn.ReLU(True) if use_relu else Swish(),
nn.Linear(channel // reduction, channel, bias=False), nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid() nn.Sigmoid(),
) )
def forward(self, inputs): def forward(self, inputs):
View file
@ -1,12 +1,13 @@
import torch.nn as nn
import torch import torch
import torch.nn as nn
__all__ = ['SharedMLP'] __all__ = ["SharedMLP"]
class Swish(nn.Module): class Swish(nn.Module):
def forward(self,x): def forward(self, x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
class SharedMLP(nn.Module): class SharedMLP(nn.Module):
def __init__(self, in_channels, out_channels, dim=1): def __init__(self, in_channels, out_channels, dim=1):
@ -23,11 +24,13 @@ class SharedMLP(nn.Module):
out_channels = [out_channels] out_channels = [out_channels]
layers = [] layers = []
for oc in out_channels: for oc in out_channels:
layers.extend([ layers.extend(
conv(in_channels, oc, 1), [
bn(8, oc), conv(in_channels, oc, 1),
Swish(), bn(8, oc),
]) Swish(),
]
)
in_channels = oc in_channels = oc
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
View file
@ -3,7 +3,7 @@ import torch.nn as nn
import modules.functional as F import modules.functional as F
__all__ = ['Voxelization'] __all__ = ["Voxelization"]
class Voxelization(nn.Module): class Voxelization(nn.Module):
@ -17,7 +17,10 @@ class Voxelization(nn.Module):
coords = coords.detach() coords = coords.detach()
norm_coords = coords - coords.mean(2, keepdim=True) norm_coords = coords - coords.mean(2, keepdim=True)
if self.normalize: if self.normalize:
norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5 norm_coords = (
norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps)
+ 0.5
)
else: else:
norm_coords = (norm_coords + 1) / 2.0 norm_coords = (norm_coords + 1) / 2.0
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1) norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
@ -25,4 +28,4 @@ class Voxelization(nn.Module):
return F.avg_voxelize(features, vox_coords, self.r), norm_coords return F.avg_voxelize(features, vox_coords, self.r), norm_coords
def extra_repr(self): def extra_repr(self):
return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '') return "resolution={}{}".format(self.r, ", normalized eps = {}".format(self.eps) if self.normalize else "")
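In formula form, the normalized branch above centers each cloud $P$ at its mean $\bar{p}$, squeezes it into the unit cube, and scales to the voxel resolution $r$:

$$
\hat{p} = \frac{p - \bar{p}}{2 \max_{q \in P} \lVert q - \bar{p} \rVert_2 + \epsilon} + \tfrac{1}{2}, \qquad v = \operatorname{clamp}(r\,\hat{p},\ 0,\ r - 1),
$$

after which the rounded $v$ (the vox_coords consumed by avg_voxelize) index the $r^3$ grid.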
View file
@@ -1,26 +1,27 @@
import argparse
from pprint import pprint

import torch.nn as nn
import torch.utils.data
from torch.distributions import Normal

from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
from metrics.evaluation_metrics import EMD_CD, compute_all_metrics
from model.pvcnn_completion import PVCNN2Base
from utils.file_utils import *

"""
models
"""


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between normal distributions parameterized by mean and log-variance.
    """
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
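`normal_kl` is the closed-form KL divergence between diagonal Gaussians given means and log-variances; it vanishes when the two parameter sets coincide, and with unit variances it reduces to half the squared mean difference:

import torch

def normal_kl(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))

z = torch.zeros(3)
print(normal_kl(z, z, z, z))        # tensor([0., 0., 0.])
print(normal_kl(z + 1.0, z, z, z))  # tensor([0.5000, 0.5000, 0.5000])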
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    # Assumes data is integers [0, 1]

@@ -31,21 +32,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 0.5)
    cdf_plus = px0.cdf(plus_in)
    min_in = inv_stdv * (centered_x - 0.5)
    cdf_min = px0.cdf(min_in)
    log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
    log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(
        x < 0.001,
        log_cdf_plus,
        torch.where(
            x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
        ),
    )
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points):
        self.loss_type = loss_type

@@ -54,15 +57,15 @@ class GaussianDiffusion:
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.sv_points = sv_points
        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1.0 - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()

@@ -70,21 +73,23 @@ class GaussianDiffusion:
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(
            torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
        )
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
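These cached tensors are the standard DDPM posterior coefficients. The harmonic-mean identity quoted in the comment above can be checked numerically on a toy schedule:

import numpy as np

betas = np.linspace(1e-4, 0.02, 10)  # toy linear schedule
alphas = 1.0 - betas
acp = np.cumprod(alphas)             # alpha-bar_t
acp_prev = np.append(1.0, acp[:-1])  # alpha-bar_{t-1}

post_var = betas * (1.0 - acp_prev) / (1.0 - acp)
alt = 1.0 / (1.0 / (1.0 - acp_prev) + alphas / betas)  # the form in the code comment
assert np.allclose(post_var, alt)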
    @staticmethod
    def _extract(a, t, x_shape):

@@ -92,17 +97,15 @@ class GaussianDiffusion:
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        (bs,) = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance

@@ -114,54 +117,59 @@ class GaussianDiffusion:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )
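`q_sample` draws x_t in a single shot from q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I) instead of stepping the chain t times. A self-contained equivalent (shapes illustrative):

import torch

T = 100
betas = torch.linspace(1e-4, 0.02, T)
acp = torch.cumprod(1.0 - betas, dim=0)  # alpha-bar

x0 = torch.randn(8, 3, 2048)             # a batch of point clouds
t = torch.randint(0, T, (8,))
noise = torch.randn_like(x0)

coef = lambda v: v[t].view(-1, 1, 1)     # per-sample coefficient, broadcast over (C, N)
x_t = coef(acp.sqrt()) * x0 + coef((1 - acp).sqrt()) * noise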
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        model_output = denoise_fn(data, t)[:, :, self.sv_points :]

        if self.model_var_type in ["fixedsmall", "fixedlarge"]:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                "fixedlarge": (
                    self.betas.to(data.device),
                    torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
                ),
                "fixedsmall": (
                    self.posterior_variance.to(data.device),
                    self.posterior_log_variance_clipped.to(data.device),
                ),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == "eps":
            x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output)
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t)
        else:
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape
        assert model_variance.shape == model_log_variance.shape
        if return_pred_xstart:

@@ -172,30 +180,31 @@ class GaussianDiffusion:
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    """ samples """

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
        noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device)

        # no noise when t == 0
        nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1))
        sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
        sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1)
        return (sample, pred_xstart) if return_pred_xstart else sample
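In this completion variant the first `sv_points` columns of `data` are the observed partial cloud: only the remaining columns are denoised, and the observed part is re-concatenated after every step. The mechanics in isolation, with stand-in tensors for the `p_mean_variance` outputs:

import torch

sv = 200                                   # observed (conditioning) points
data = torch.randn(4, 3, 2048)             # [partial | current noisy free points]
model_mean = torch.randn(4, 3, 2048 - sv)  # stand-ins for p_mean_variance outputs
model_log_var = torch.zeros_like(model_mean)
t = torch.randint(0, 1000, (4,))

noise = torch.randn_like(model_mean)
nonzero = (1 - (t == 0).float()).view(-1, 1, 1)        # no noise at t == 0
sample = model_mean + nonzero * torch.exp(0.5 * model_log_var) * noise
sample = torch.cat([data[:, :, :sv], sample], dim=-1)  # re-attach the observed points
assert sample.shape == data.shape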
    def p_sample_loop(
        self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False
    ):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

@@ -206,14 +215,21 @@ class GaussianDiffusion:
        img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1)
        for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            )

        assert img_t[:, :, self.sv_points :].shape == shape
        return img_t

    def p_sample_loop_trajectory(
        self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False
    ):
        """
        Generate samples, returning intermediate images
        Useful for visualizing how denoised images evolve over time

@@ -223,31 +239,38 @@ class GaussianDiffusion:
        """
        assert isinstance(shape, (tuple, list))

        total_steps = self.num_timesteps if not keep_running else len(self.betas)

        img_t = noise_fn(size=shape, dtype=torch.float, device=device)
        imgs = [img_t]
        for t in reversed(range(0, total_steps)):
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            )
            if t % freq == 0 or t == total_steps - 1:
                imgs.append(img_t)

        assert imgs[-1].shape == shape
        return imgs

    """losses"""

    def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool):
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t
        )
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
        kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
        kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0)

        return (kl, pred_xstart) if return_pred_xstart else kl

@@ -259,66 +282,87 @@ class GaussianDiffusion:
        assert t.shape == torch.Size([B])

        if noise is None:
            noise = torch.randn(
                data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device
            )

        data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise)

        if self.loss_type == "mse":
            # predict the noise instead of x_start. seems to be weighted naturally like SNR
            eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[
                :, :, self.sv_points :
            ]
            losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape))))
        elif self.loss_type == "kl":
            losses = self._vb_terms_bpd(
                denoise_fn=denoise_fn,
                data_start=data_start,
                data_t=data_t,
                t=t,
                clip_denoised=False,
                return_pred_xstart=False,
            )
        else:
            raise NotImplementedError(self.loss_type)

        assert losses.shape == torch.Size([B])
        return losses
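With `loss_type == "mse"` this is plain noise prediction: the network sees `[partial | x_t]` and regresses the injected noise on the free points only. A toy version with a hypothetical zero-output network, just to show the shapes:

import torch

sv, B = 200, 4
data_start = torch.randn(B, 3, 2048)
noise = torch.randn(B, 3, 2048 - sv)
data_t = torch.randn_like(noise)  # stand-in for q_sample(x_start, t, noise)

def denoise_fn(x, t):             # hypothetical network: right shape, wrong answer
    return torch.zeros_like(x)

eps_recon = denoise_fn(torch.cat([data_start[:, :, :sv], data_t], dim=-1), None)[:, :, sv:]
losses = ((noise - eps_recon) ** 2).mean(dim=[1, 2])
assert losses.shape == (B,)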
    """debug"""

    def _prior_bpd(self, x_start):
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps
            t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1)
            qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_)
            kl_prior = normal_kl(
                mean1=qt_mean,
                logvar1=qt_log_variance,
                mean2=torch.tensor([0.0]).to(qt_mean),
                logvar2=torch.tensor([0.0]).to(qt_log_variance),
            )
            assert kl_prior.shape == x_start.shape
            return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0)

    def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True):
        with torch.no_grad():
            B, T = x_start.shape[0], self.num_timesteps

            vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device)
            for t in reversed(range(T)):
                t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t)
                # Calculate VLB term at the current timestep
                data_t = torch.cat(
                    [x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)],
                    dim=-1,
                )
                new_vals_b, pred_xstart = self._vb_terms_bpd(
                    denoise_fn,
                    data_start=x_start,
                    data_t=data_t,
                    t=t_b,
                    clip_denoised=clip_denoised,
                    return_pred_xstart=True,
                )
                # MSE for progressive prediction loss
                assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape
                new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean(
                    dim=list(range(1, len(pred_xstart.shape)))
                )
                assert new_vals_b.shape == new_mse_b.shape == torch.Size([B])
                # Insert the calculated term into the tensor of all terms
                mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float()
                vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt
                mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt
                assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T])

            prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :])
            total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b
            assert vals_bt_.shape == mse_bt_.shape == torch.Size(
                [B, T]
            ) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B])
            return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean()
@@ -336,39 +380,53 @@ class PVCNN2(PVCNN2Base):
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(
        self,
        num_classes,
        sv_points,
        embed_dim,
        use_att,
        dropout,
        extra_feature_channels=3,
        width_multiplier=1,
        voxel_resolution_multiplier=1,
    ):
        super().__init__(
            num_classes=num_classes,
            sv_points=sv_points,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
        )


class Model(nn.Module):
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints)

        self.model = PVCNN2(
            num_classes=args.nc,
            sv_points=args.svpoints,
            embed_dim=args.embed_dim,
            use_att=args.attention,
            dropout=args.dropout,
            extra_feature_channels=0,
        )

    def prior_kl(self, x0):
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}

    def _denoise(self, data, t):
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

@@ -381,20 +439,22 @@ class Model(nn.Module):
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False):
        return self.diffusion.p_sample_loop(
            partial_x,
            self._denoise,
            shape=shape,
            device=device,
            noise_fn=noise_fn,
            clip_denoised=clip_denoised,
            keep_running=keep_running,
        )

    def train(self):
        self.model.train()

@@ -405,21 +465,19 @@ class Model(nn.Module):
    def multi_gpu_wrapper(self, f):
        self.model = f(self.model)


def get_betas(schedule_type, b_start, b_end, time_num):
    if schedule_type == "linear":
        betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type == "warm0.1":
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * 0.1)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.2":
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * 0.2)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.5":
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * 0.5)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)

@@ -428,22 +486,29 @@ def get_betas(schedule_type, b_start, b_end, time_num):
    return betas
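The `warm*` schedules hold beta constant at `b_end` and linearly ramp the first 10/20/50% of steps up from `b_start`; `linear` is a straight ramp over all steps. For example:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)  # the default "linear" schedule
print(betas.shape, betas[0], betas[-1])  # (1000,) 0.0001 0.02

warm = 0.02 * np.ones(1000, dtype=np.float64)  # "warm0.1"
warm[:100] = np.linspace(0.0001, 0.02, 100, dtype=np.float64)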
#############################################################################


def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
    tr_dataset = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=[category],
        split="train",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    te_dataset = ShapeNet_Multiview_Points(
        root_pc=pc_dataroot,
        root_views=views_root,
        cache=os.path.join(pc_dataroot, "../cache"),
        split="val",
        categories=[category],
        npoints=npoints,
        sv_samples=200,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
    )
@@ -451,39 +516,41 @@ def get_mvr_dataset(pc_dataroot, views_root, npoints, category):
def evaluate_recon_mvr(opt, model, save_dir, logger):
    test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.category)

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
    )

    ref = []
    samples = []
    masked = []
    for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Reconstructing Samples"):
        gt_all = data["test_points"]
        x_all = data["sv_points"]

        B, V, N, C = x_all.shape

        gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1)

        x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous()
        m, s = data["mean"].float(), data["std"].float()

        recon = (
            model.gen_samples(
                x[:, :, : opt.svpoints].cuda(), x[:, :, opt.svpoints :].shape, "cuda", clip_denoised=False
            )
            .detach()
            .cpu()
        )

        recon = recon.transpose(1, 2).contiguous()
        x = x.transpose(1, 2).contiguous()

        x_adj = x.reshape(B, V, N, C) * s + m
        recon_adj = recon.reshape(B, V, N, C) * s + m

        ref.append(gt_all * s + m)
        masked.append(x_adj[:, :, : test_dataloader.dataset.sv_samples, :])
        samples.append(recon_adj)

    ref_pcs = torch.cat(ref, dim=0)

@@ -492,31 +559,40 @@ def evaluate_recon_mvr(opt, model, save_dir, logger):
    B, V, N, C = ref_pcs.shape

    torch.save(ref_pcs.reshape(B, V, N, C), os.path.join(save_dir, "recon_gt.pth"))
    torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, "recon_masked.pth"))
    # Compute metrics
    results = EMD_CD(sample_pcs.reshape(B * V, N, C), ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False)

    results = {
        ky: val.reshape(B, V)
        if val.shape
        == torch.Size(
            [
                B * V,
            ]
        )
        else val
        for ky, val in results.items()
    }

    pprint({key: val.mean().item() for key, val in results.items()})
    logger.info({key: val.mean().item() for key, val in results.items()})

    results["pc"] = sample_pcs
    torch.save(results, os.path.join(save_dir, "ours_results.pth"))

    del ref_pcs, masked, results
def evaluate_saved(opt, saved_dir):
    # ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn'
    gt_pth = saved_dir + "/recon_gt.pth"
    ours_pth = saved_dir + "/ours_results.pth"
    gt = torch.load(gt_pth).permute(1, 0, 2, 3)
    ours = torch.load(ours_pth)["pc"].permute(1, 0, 2, 3)

    all_res = {}
    for i, (gt_, ours_) in enumerate(zip(gt, ours)):

@@ -534,7 +610,6 @@ def evaluate_saved(opt, saved_dir):
    pprint({key: val.mean().item() for key, val in all_res.items()})


def main(opt):
    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)

@@ -542,7 +617,7 @@ def main(opt):
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    (outf_syn,) = setup_output_subdirs(output_dir, "syn")

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)

@@ -559,12 +634,10 @@ def main(opt):
    model.eval()

    with torch.no_grad():
        logger.info("Resume Path:%s" % opt.model)

        resumed_param = torch.load(opt.model)
        model.load_state_dict(resumed_param["model_state"])

        if opt.eval_recon_mvr:
            # Evaluate generation

@@ -575,47 +648,44 @@ def main(opt):
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/")
    parser.add_argument("--dataroot_sv", default="GenReData/")
    parser.add_argument("--category", default="chair")

    parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
    parser.add_argument("--workers", type=int, default=16, help="workers")
    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")

    parser.add_argument("--eval_recon_mvr", default=True)
    parser.add_argument("--eval_saved", default=True)

    parser.add_argument("--nc", default=3)
    parser.add_argument("--npoints", default=2048)
    parser.add_argument("--svpoints", default=200)
    """model"""
    parser.add_argument("--beta_start", default=0.0001)
    parser.add_argument("--beta_end", default=0.02)
    parser.add_argument("--schedule_type", default="linear")
    parser.add_argument("--time_num", default=1000)

    # params
    parser.add_argument("--attention", default=True)
    parser.add_argument("--dropout", default=0.1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--loss_type", default="mse")
    parser.add_argument("--model_mean_type", default="eps")
    parser.add_argument("--model_var_type", default="fixedsmall")

    parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")

    """eval"""
    parser.add_argument("--eval_path", default="")
    parser.add_argument("--manualSeed", default=42, type=int, help="random seed")

    parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")

    opt = parser.parse_args()

@@ -625,7 +695,9 @@ def parse_args():
    opt.cuda = False

    return opt


if __name__ == "__main__":
    opt = parse_args()
    main(opt)
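`--model` is the only required flag, so a typical evaluation run looks like the line below (the script and checkpoint names are illustrative, not taken from this diff):

python test_completion.py --model output/checkpoint.pth --category chair --batch_size 50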


@@ -1,31 +1,30 @@
import argparse
from pprint import pprint

import torch
import torch.nn as nn
import torch.utils.data
from torch.distributions import Normal
from tqdm import tqdm

from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from metrics.evaluation_metrics import compute_all_metrics
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
from model.pvcnn_generation import PVCNN2Base
from utils.file_utils import *
from utils.visualize import *

"""
models
"""


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between normal distributions parameterized by mean and log-variance.
    """
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    # Assumes data is integers [0, 1]

@@ -36,37 +35,40 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 0.5)
    cdf_plus = px0.cdf(plus_in)
    min_in = inv_stdv * (centered_x - 0.5)
    cdf_min = px0.cdf(min_in)
    log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
    log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12))
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(
        x < 0.001,
        log_cdf_plus,
        torch.where(
            x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))
        ),
    )
    assert log_probs.shape == x.shape
    return log_probs
class GaussianDiffusion:
    def __init__(self, betas, loss_type, model_mean_type, model_var_type):
        self.loss_type = loss_type
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        assert isinstance(betas, np.ndarray)
        self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
        assert (betas > 0).all() and (betas <= 1).all()
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)

        # initialize twice the actual length so we can keep running for eval
        # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])])

        alphas = 1.0 - betas
        alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float()
        alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float()

        self.betas = torch.from_numpy(betas).float()
        self.alphas_cumprod = alphas_cumprod.float()

@@ -74,21 +76,23 @@ class GaussianDiffusion:
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float()
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float()
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float()
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float()
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float()

        betas = torch.from_numpy(betas).float()
        alphas = torch.from_numpy(alphas).float()

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.posterior_variance = posterior_variance
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(
            torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))
        )
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
    @staticmethod
    def _extract(a, t, x_shape):

@@ -96,17 +100,15 @@ class GaussianDiffusion:
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        (bs,) = t.shape
        assert x_shape[0] == bs
        out = torch.gather(a, 0, t)
        assert out.shape == torch.Size([bs])
        return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_mean_variance(self, x_start, t):
        mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape)
        return mean, variance, log_variance
@@ -118,56 +120,62 @@ class GaussianDiffusion:
            noise = torch.randn(x_start.shape, device=x_start.device)
        assert noise.shape == x_start.shape
        return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool):
        model_output = denoise_fn(data, t)

        if self.model_var_type in ["fixedsmall", "fixedlarge"]:
            # below: only log_variance is used in the KL computations
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
                "fixedlarge": (
                    self.betas.to(data.device),
                    torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device),
                ),
                "fixedsmall": (
                    self.posterior_variance.to(data.device),
                    self.posterior_log_variance_clipped.to(data.device),
                ),
            }[self.model_var_type]
            model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
            model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
        else:
            raise NotImplementedError(self.model_var_type)

        if self.model_mean_type == "eps":
            x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)

            if clip_denoised:
                x_recon = torch.clamp(x_recon, -0.5, 0.5)

            model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
        else:
            raise NotImplementedError(self.loss_type)

        assert model_mean.shape == x_recon.shape == data.shape
        assert model_variance.shape == model_log_variance.shape == data.shape
        if return_pred_xstart:
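Unlike the completion variant above, this generation version clamps the predicted x0 when `clip_denoised` is set, presumably because the training clouds are normalized into that box:

import torch

x_recon = torch.randn(4, 3, 2048) * 2.0    # an out-of-range prediction
x_recon = torch.clamp(x_recon, -0.5, 0.5)  # what clip_denoised=True does to x0-hat
assert x_recon.abs().max() <= 0.5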
@@ -178,18 +186,19 @@ class GaussianDiffusion:
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )

    """ samples """

    def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True):
        """
        Sample from the model
        """
        model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
            denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True
        )
        noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
        assert noise.shape == data.shape
        # no noise when t == 0

@@ -201,10 +210,17 @@ class GaussianDiffusion:
        assert sample.shape == pred_xstart.shape
        return (sample, pred_xstart) if return_pred_xstart else sample
    def p_sample_loop(
        self,
        denoise_fn,
        shape,
        device,
        noise_fn=torch.randn,
        constrain_fn=lambda x, t: x,
        clip_denoised=True,
        max_timestep=None,
        keep_running=False,
    ):
        """
        Generate samples
        keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps
@@ -220,28 +236,38 @@ class GaussianDiffusion:
        for t in reversed(range(0, final_time if not keep_running else len(self.betas))):
            img_t = constrain_fn(img_t, t)
            t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=clip_denoised,
                return_pred_xstart=False,
            ).detach()

        assert img_t.shape == shape
        return img_t

    def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x):
        assert t >= 1

        t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1)
        encoding = self.q_sample(x0, t_vec)

        img_t = encoding

        for k in reversed(range(0, t)):
            img_t = constrain_fn(img_t, k)
            t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k)
            img_t = self.p_sample(
                denoise_fn=denoise_fn,
                data=img_t,
                t=t_,
                noise_fn=noise_fn,
                clip_denoised=False,
                return_pred_xstart=False,
                use_var=True,
            ).detach()

        return img_t
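    # [Editor's sketch, not part of the diff] reconstruct() is encode-then-decode:
    # q_sample diffuses x0 forward to step t - 1, then t reverse p_sample steps
    # bring it back. Hypothetical usage, assuming (B, 3, N) point-cloud batches:
    #
    #     x0 = torch.randn(2, 3, 2048)
    #     x_rec = diffusion.reconstruct(x0, t=50, denoise_fn=model._denoise)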
@@ -260,40 +286,50 @@ class PVCNN2(PVCNN2Base):
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(
        self,
        num_classes,
        embed_dim,
        use_att,
        dropout,
        extra_feature_channels=3,
        width_multiplier=1,
        voxel_resolution_multiplier=1,
    ):
        super().__init__(
            num_classes=num_classes,
            embed_dim=embed_dim,
            use_att=use_att,
            dropout=dropout,
            extra_feature_channels=extra_feature_channels,
            width_multiplier=width_multiplier,
            voxel_resolution_multiplier=voxel_resolution_multiplier,
        )
class Model(nn.Module):
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        self.model = PVCNN2(
            num_classes=args.nc,
            embed_dim=args.embed_dim,
            use_att=args.attention,
            dropout=args.dropout,
            extra_feature_channels=0,
        )

    def prior_kl(self, x0):
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, clip_denoised=True):
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised)

        return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt}

    def _denoise(self, data, t):
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64
@@ -307,23 +343,34 @@ class Model(nn.Module):
            t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(
        self,
        shape,
        device,
        noise_fn=torch.randn,
        constrain_fn=lambda x, t: x,
        clip_denoised=False,
        max_timestep=None,
        keep_running=False,
    ):
        return self.diffusion.p_sample_loop(
            self._denoise,
            shape=shape,
            device=device,
            noise_fn=noise_fn,
            constrain_fn=constrain_fn,
            clip_denoised=clip_denoised,
            max_timestep=max_timestep,
            keep_running=keep_running,
        )
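    # [Editor's sketch, not part of the diff] gen_samples simply forwards to
    # p_sample_loop with the wrapped denoiser, e.g. for 16 clouds of 2048 points:
    #
    #     samples = model.gen_samples((16, 3, 2048), "cuda", clip_denoised=False)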
    def reconstruct(self, x0, t, constrain_fn=lambda x, t: x):
        return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn)

    def train(self):
@@ -337,20 +384,17 @@ class Model(nn.Module):
def get_betas(schedule_type, b_start, b_end, time_num):
    if schedule_type == "linear":
        betas = np.linspace(b_start, b_end, time_num)
    elif schedule_type == "warm0.1":
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * 0.1)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.2":
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * 0.2)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
    elif schedule_type == "warm0.5":
        betas = b_end * np.ones(time_num, dtype=np.float64)
        warmup_time = int(time_num * 0.5)
        betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64)
@@ -358,111 +402,109 @@ def get_betas(schedule_type, b_start, b_end, time_num):
        raise NotImplementedError(schedule_type)
    return betas
def get_constrain_function(ground_truth, mask, eps, num_steps=1):
    """
    :param target_shape_constraint: target voxels
    :return: constrained x
    """
    # eps_all = list(reversed(np.linspace(0, np.float_power(eps, 1 / 2), 500) ** 2))
    eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2))

    def constrain_fn(x, t):
        eps_ = eps_all[t] if (t < 1000) else 0
        for _ in range(num_steps):
            x = x - eps_ * ((x - ground_truth) * mask)

        return x

    return constrain_fn
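# [Editor's sketch, not part of the diff] constrain_fn nudges the sample x
# toward ground_truth wherever mask == 1, with a step size that grows as
# t -> 0, so the constraint tightens toward the end of sampling. Hypothetical
# shape-completion style use, assuming (B, 3, N) data:
#
#     gt = torch.zeros(1, 3, 2048)      # known partial shape
#     mask = torch.zeros_like(gt)
#     mask[..., :1024] = 1              # constrain only the first half of the points
#     constrain_fn = get_constrain_function(gt, mask, eps=0.1)
#     x_constrained = constrain_fn(torch.randn_like(gt), t=999)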
#############################################################################
def get_dataset(dataroot, npoints, category, use_mask=False):
    tr_dataset = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=[category],
        split="train",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
        use_mask=use_mask,
    )
    te_dataset = ShapeNet15kPointClouds(
        root_dir=dataroot,
        categories=[category],
        split="val",
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.0,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
        use_mask=use_mask,
    )
    return tr_dataset, te_dataset
def evaluate_gen(opt, ref_pcs, logger):
    if ref_pcs is None:
        _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
        )
        ref = []
        for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
            x = data["test_points"]
            m, s = data["mean"].float(), data["std"].float()

            ref.append(x * s + m)

        ref_pcs = torch.cat(ref, dim=0).contiguous()

    logger.info("Loading sample path: %s" % (opt.eval_path))
    sample_pcs = torch.load(opt.eval_path).contiguous()

    logger.info("Generation sample size:%s reference size: %s" % (sample_pcs.size(), ref_pcs.size()))

    # Compute metrics
    results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size)
    results = {k: (v.cpu().detach().item() if not isinstance(v, float) else v) for k, v in results.items()}

    pprint(results)
    logger.info(results)

    jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy())
    pprint("JSD: {}".format(jsd))
    logger.info("JSD: {}".format(jsd))
def generate(model, opt):
    _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
    )

    with torch.no_grad():
        samples = []
        ref = []

        for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Generating Samples"):
            x = data["test_points"].transpose(1, 2)
            m, s = data["mean"].float(), data["std"].float()

            gen = model.gen_samples(x.shape, "cuda", clip_denoised=False).detach().cpu()

            gen = gen.transpose(1, 2).contiguous()
            x = x.transpose(1, 2).contiguous()

            gen = gen * s + m
            x = x * s + m
@@ -482,20 +524,20 @@ def generate(model, opt):
            #     1,
            #     0.5,
            # )

            # visualize using matplotlib
            import matplotlib
            import matplotlib.cm as cm
            import matplotlib.pyplot as plt

            matplotlib.use("TkAgg")

            for idx, pc in enumerate(gen[:64]):
                print(f"Visualizing point cloud {idx}...")
                fig = plt.figure()
                ax = fig.add_subplot(111, projection="3d")
                ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
                ax.set_aspect("equal")
                ax.axis("off")
                # ax.set_xlim(-1, 1)
                # ax.set_ylim(-1, 1)
                # ax.set_zlim(-1, 1)
@@ -507,17 +549,14 @@ def generate(model, opt):
    torch.save(samples, opt.eval_path)
    return ref
def main(opt):
    if opt.category == "airplane":
        opt.beta_start = 1e-5
        opt.beta_end = 0.008
        opt.schedule_type = "warm0.1"

    exp_id = os.path.splitext(os.path.basename(__file__))[0]
    dir_id = os.path.dirname(__file__)
@@ -525,7 +564,7 @@ def main(opt):
    copy_source(__file__, output_dir)
    logger = setup_logging(output_dir)

    (outf_syn,) = setup_output_subdirs(output_dir, "syn")

    betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num)
    model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type)
@@ -542,64 +581,59 @@ def main(opt):
    model.eval()

    with torch.no_grad():
        logger.info("Resume Path:%s" % opt.model)

        resumed_param = torch.load(opt.model)
        model.load_state_dict(resumed_param["model_state"])

        ref = None
        if opt.generate:
            opt.eval_path = os.path.join(outf_syn, "samples.pth")
            Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True)
            ref = generate(model, opt)

        if opt.eval_gen:
            # Evaluate generation
            evaluate_gen(opt, ref, logger)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
    parser.add_argument("--category", default="chair")

    parser.add_argument("--batch_size", type=int, default=50, help="input batch size")
    parser.add_argument("--workers", type=int, default=16, help="workers")
    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")

    parser.add_argument("--generate", default=True)
    parser.add_argument("--eval_gen", default=True)

    parser.add_argument("--nc", default=3)
    parser.add_argument("--npoints", default=2048)
    """model"""
    parser.add_argument("--beta_start", default=0.0001)
    parser.add_argument("--beta_end", default=0.02)
    parser.add_argument("--schedule_type", default="linear")
    parser.add_argument("--time_num", default=1000)

    # params
    parser.add_argument("--attention", default=True)
    parser.add_argument("--dropout", default=0.1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--loss_type", default="mse")
    parser.add_argument("--model_mean_type", default="eps")
    parser.add_argument("--model_var_type", default="fixedsmall")

    parser.add_argument("--model", default="", required=True, help="path to model (to continue training)")

    """eval"""
    parser.add_argument("--eval_path", default="")

    parser.add_argument("--manualSeed", default=42, type=int, help="random seed")

    parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)")

    opt = parser.parse_args()
@@ -609,7 +643,9 @@ def parse_args():
    opt.cuda = False
    return opt


if __name__ == "__main__":
    opt = parse_args()
    set_seed(opt)

File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,35 +1,33 @@
import datetime
import logging
import os
import random
import sys
from shutil import copyfile

import torch

logger = logging.getLogger()

import numpy as np
def set_global_gpu_env(opt):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)

    torch.cuda.set_device(opt.gpu)


def copy_source(file, output_dir):
    copyfile(file, os.path.join(output_dir, os.path.basename(file)))


def setup_logging(output_dir):
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    logger = logging.getLogger()
    logger.handlers = []
    output_file = os.path.join(output_dir, "output.log")
    file_handler = logging.FileHandler(output_file)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
@@ -44,16 +42,14 @@ def setup_logging(output_dir):
def get_output_dir(prefix, exp_id):
    t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    output_dir = os.path.join(prefix, "output/" + exp_id, t)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    return output_dir


def set_seed(opt):
    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
@@ -65,8 +61,8 @@ def set_seed(opt):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
def setup_output_subdirs(output_dir, *subfolders):
    output_subdirs = output_dir
    try:
        os.makedirs(output_subdirs)
@@ -82,4 +78,4 @@ def setup_output_subdirs(output_dir, *subfolders):
        pass

        subfolder_list.append(curr_subf)

    return subfolder_list

@@ -1,20 +1,22 @@
import warnings

import numpy as np
from scipy.stats import entropy


def iterate_in_chunks(l, n):
    """Yield successive 'n'-sized chunks from iterable 'l'.
    Note: last chunk will be smaller than l if n doesn't divide l perfectly.
    """
    for i in range(0, len(l), n):
        yield l[i : i + n]
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
    that is placed in the unit-cube.
    If clip_sphere is True it drops the "corner" cells that lie outside the unit-sphere.
    """
    grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
    spacing = 1.0 / float(resolution - 1)
    for i in range(resolution):
@@ -30,9 +32,11 @@ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    return grid, spacing
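# [Editor's sketch, not part of the diff] the grid holds resolution^3 cell
# centers inside the unit cube, spaced 1 / (resolution - 1) apart:
#
#     grid, spacing = unit_cube_grid_point_cloud(4)
#     assert grid.shape == (4, 4, 4, 3)
#     assert abs(spacing - 1.0 / 3.0) < 1e-6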
def minimum_mathing_distance(
    sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False
):
    """Computes the MMD between two sets of point-clouds.
    Args:
        sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched and
            compared to a set of "reference" point-clouds.
@@ -49,17 +53,17 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
        use_EMD (boolean): If true, the matchings are based on the EMD.
    Returns:
        A tuple containing the MMD and all the matched distances of which the MMD is their mean.
    """
    n_ref, n_pc_points, pc_dim = ref_pcs.shape
    _, n_pc_points_s, pc_dim_s = sample_pcs.shape

    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError("Incompatible size of point-clouds.")

    ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(
        n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
    )
    matched_dists = []
    for i in range(n_ref):
        best_in_all_batches = []
@@ -75,9 +79,18 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
    return mmd, matched_dists
def coverage(
    sample_pcs,
    ref_pcs,
    batch_size,
    normalize=True,
    sess=None,
    verbose=False,
    use_sqrt=False,
    use_EMD=False,
    ret_dist=False,
):
    """Computes the Coverage between two sets of point-clouds.
    Args:
        sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
            and compared to a set of "reference" point-clouds.
@@ -97,18 +110,16 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
    Returns: the coverage score (int),
        the indices of the ref_pcs that are matched with each sample_pc
        and optionally the matched distances of the samples_pcs.
    """
    n_ref, n_pc_points, pc_dim = ref_pcs.shape
    n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape

    if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
        raise ValueError("Incompatible Point-Clouds.")

    ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(
        n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
    )
    matched_gt = []
    matched_dist = []
    for i in range(n_sam):  # range (Python 3), not the Python 2 xrange in the original
@@ -140,12 +151,12 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
    """Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
    Args:
        sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
        ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
        resolution: (int) grid-resolution. Affects granularity of measurements.
    """
    in_unit_sphere = True
    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
@@ -153,19 +164,19 @@ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
    """Given a collection of point-clouds, estimate the entropy of the random variables
    corresponding to occupancy-grid activation patterns.
    Inputs:
        pclouds: (numpy array) #point-clouds x points per point-cloud x 3
        grid_resolution (int) size of occupancy grid that will be used.
    """
    epsilon = 10e-4
    bound = 0.5 + epsilon
    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        warnings.warn("Point-clouds are not in unit cube.")

    if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
        warnings.warn("Point-clouds are not in unit sphere.")

    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    grid_coordinates = grid_coordinates.reshape(-1, 3)
@@ -192,13 +203,14 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
    return acc_entropy / len(grid_counters), grid_counters


def jensen_shannon_divergence(P, Q):
    if np.any(P < 0) or np.any(Q < 0):
        raise ValueError("Negative values.")
    if len(P) != len(Q):
        raise ValueError("Non equal size.")

    P_ = P / np.sum(P)  # Ensure probabilities.
    Q_ = Q / np.sum(Q)

    e1 = entropy(P_, base=2)
@@ -209,13 +221,14 @@ def jensen_shannon_divergence(P, Q):
    res2 = _jsdiv(P_, Q_)

    if not np.allclose(res, res2, atol=10e-5, rtol=0):
        warnings.warn("Numerical values of two JSD methods don't agree.")

    return res
def _jsdiv(P, Q):
    """another way of computing JSD"""

    def _kldiv(A, B):
        a = A.copy()
        b = B.copy()
@@ -229,4 +242,4 @@ def _jsdiv(P, Q):
    M = 0.5 * (P_ + Q_)

    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))

@@ -1,32 +1,33 @@
import matplotlib

matplotlib.use("agg")

import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import trimesh
from mpl_toolkits.mplot3d import Axes3D

"""
Custom visualization
"""
def export_to_pc_batch(dir, pcs, colors=None):
    Path(dir).mkdir(parents=True, exist_ok=True)
    for i, xyz in enumerate(pcs):
        if colors is None:
            color = None
        else:
            color = colors[i]
        pcwrite(os.path.join(dir, "sample_" + str(i) + ".ply"), xyz, color)
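# [Editor's sketch, not part of the diff] exporting a batch of random clouds as
# sample_0.ply ... sample_3.ply (pcwrite falls back to mid-grey when colors is None):
#
#     pcs = np.random.rand(4, 2048, 3).astype(np.float32) - 0.5
#     export_to_pc_batch("/tmp/pc_samples", pcs)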
def export_to_obj(dir, meshes, transform=lambda v, f: (v, f)):
    """
    transform: f(vertices, faces) --> transformed (vertices, faces)
    """
    Path(dir).mkdir(parents=True, exist_ok=True)
    for i, data in enumerate(meshes):
        v, f = transform(data[0], data[1])
@@ -36,14 +37,15 @@ def export_to_obj(dir, meshes, transform=lambda v, f: (v, f)):
            v_color = None
        mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
        out = trimesh.exchange.obj.export_obj(mesh)
        with open(os.path.join(dir, "sample_" + str(i) + ".obj"), "w") as f:
            f.write(out)
            f.close()
def export_to_obj_single(path, data, transform=lambda v, f: (v, f)):
    """
    transform: f(vertices, faces) --> transformed (vertices, faces)
    """
    v, f = transform(data[0], data[1])
    if len(data) > 2:
        v_color = data[2]
@@ -51,15 +53,15 @@ def export_to_obj_single(path, data, transform=lambda v, f: (v, f)):
        v_color = None
    mesh = trimesh.Trimesh(v, f, vertex_colors=v_color)
    out = trimesh.exchange.obj.export_obj(mesh)
    with open(path, "w") as f:
        f.write(out)
        f.close()
def meshwrite(filename, verts, faces, norms, colors):
    """Save a 3D mesh to a polygon .ply file."""

    # Write header
    ply_file = open(filename, "w")
    ply_file.write("ply\n")
    ply_file.write("format ascii 1.0\n")
    ply_file.write("element vertex %d\n" % (verts.shape[0]))
@@ -78,11 +80,20 @@ def meshwrite(filename, verts, faces, norms, colors):
    # Write vertex list
    for i in range(verts.shape[0]):
        ply_file.write(
            "%f %f %f %f %f %f %d %d %d\n"
            % (
                verts[i, 0],
                verts[i, 1],
                verts[i, 2],
                norms[i, 0],
                norms[i, 1],
                norms[i, 2],
                colors[i, 0],
                colors[i, 1],
                colors[i, 2],
            )
        )

    # Write face list
    for i in range(faces.shape[0]):
@@ -92,14 +103,13 @@ def meshwrite(filename, verts, faces, norms, colors):
def pcwrite(filename, xyz, rgb=None):
    """Save a point cloud to a polygon .ply file."""
    if rgb is None:
        rgb = np.ones_like(xyz) * 128
    rgb = rgb.astype(np.uint8)

    # Write header
    ply_file = open(filename, "w")
    ply_file.write("ply\n")
    ply_file.write("format ascii 1.0\n")
    ply_file.write("element vertex %d\n" % (xyz.shape[0]))
@@ -113,60 +123,67 @@ def pcwrite(filename, xyz, rgb=None):
    # Write vertex list
    for i in range(xyz.shape[0]):
        ply_file.write(
            "%f %f %f %d %d %d\n"
            % (
                xyz[i, 0],
                xyz[i, 1],
                xyz[i, 2],
                rgb[i, 0],
                rgb[i, 1],
                rgb[i, 2],
            )
        )
"""
Matplotlib Visualization
"""
def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
    r"""Visualizes voxel data.
    show only first num_shown
    """
    batch_size = voxels.shape[0]
    voxels = voxels.squeeze(1) > threshold

    num_shown = min(num_shown, batch_size)

    n = int(np.sqrt(num_shown))
    fig = plt.figure(figsize=(20, 20))

    for idx, pc in enumerate(voxels[:num_shown]):
        if idx >= n * n:
            break
        pc = voxels[idx]
        ax = fig.add_subplot(n, n, idx + 1, projection="3d")
        ax.voxels(pc, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5)
        ax.view_init()
        ax.axis("off")
    plt.savefig(out_file, bbox_inches="tight")
    plt.close()
def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
    r"""Visualizes point cloud data.
    Args:
        points (tensor): point data
        normals (tensor): normal data (if existing)
        out_file (string): output file
        show (bool): whether the plot should be shown
    """
    # Create plot
    fig = plt.figure()
    ax = fig.gca(projection=Axes3D.name)
    ax.scatter(points[:, 2], points[:, 0], points[:, 1])
    if normals is not None:
        ax.quiver(
            points[:, 2], points[:, 0], points[:, 1], normals[:, 2], normals[:, 0], normals[:, 1], length=0.1, color="k"
        )
    ax.set_xlabel("Z")
    ax.set_ylabel("X")
    ax.set_zlabel("Y")
    # ax.set_xlim(-0.5, 0.5)
    # ax.set_ylim(-0.5, 0.5)
    # ax.set_zlim(-0.5, 0.5)
@@ -178,37 +195,39 @@ def visualize_pointcloud(points, normals=None,
    plt.close(fig)
def visualize_pointcloud_batch(
    path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225
):
    batch_size = len(pointclouds)
    fig = plt.figure(figsize=(20, 20))

    ncols = int(np.sqrt(batch_size))
    nrows = max(1, (batch_size - 1) // ncols + 1)
    for idx, pc in enumerate(pointclouds):
        if vis_label:
            label = categories[labels[idx].item()]
            pred = categories[pred_labels[idx]]
            colour = "g" if label == pred else "r"
        elif target is None:
            colour = "g"
        else:
            colour = target[idx]
        pc = pc.cpu().numpy()
        ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
        ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
        ax.view_init(elev=elev, azim=azim)
        ax.axis("off")
        if vis_label:
            ax.set_title("GT: {0}\nPred: {1}".format(label, pred))

    plt.savefig(path)
    plt.close(fig)
"""
Plot stats
"""


def plot_stats(output_dir, stats, interval):
    content = stats.keys()
@@ -218,5 +237,5 @@ def plot_stats(output_dir, stats, interval):
        axs[j].plot(interval, v)
        axs[j].set_ylabel(k)

    f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
    plt.close(f)