diff --git a/convert_cam_params.py b/convert_cam_params.py
index 36abb2b..d5e144d 100644
--- a/convert_cam_params.py
+++ b/convert_cam_params.py
@@ -1,32 +1,33 @@
-
-from glob import glob
-import re
 import argparse
-import numpy as np
-from pathlib import Path
 import os
+import re
+from glob import glob
+from pathlib import Path
+
+import numpy as np
+
 
 def raw_camparam_from_xml(path, pose="lookAt"):
     import xml.etree.ElementTree as ET
+
     tree = ET.parse(path)
     elm = tree.find("./sensor/transform/" + pose)
     camparam = elm.attrib
-    origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
-    target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
-    up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
-    height = int(
-        tree.find("./sensor/film/integer[@name='height']").attrib['value'])
-    width = int(
-        tree.find("./sensor/film/integer[@name='width']").attrib['value'])
+    origin = np.fromstring(camparam["origin"], dtype=np.float32, sep=",")
+    target = np.fromstring(camparam["target"], dtype=np.float32, sep=",")
+    up = np.fromstring(camparam["up"], dtype=np.float32, sep=",")
+    height = int(tree.find("./sensor/film/integer[@name='height']").attrib["value"])
+    width = int(tree.find("./sensor/film/integer[@name='width']").attrib["value"])
 
     camparam = dict()
-    camparam['origin'] = origin
-    camparam['up'] = up
-    camparam['target'] = target
-    camparam['height'] = height
-    camparam['width'] = width
+    camparam["origin"] = origin
+    camparam["up"] = up
+    camparam["target"] = target
+    camparam["height"] = height
+    camparam["width"] = width
     return camparam
 
+
 def get_cam_pos(origin, target, up):
     inward = origin - target
     right = np.cross(up, inward)
@@ -38,59 +39,54 @@ def get_cam_pos(origin, target, up):
     ry /= np.linalg.norm(ry)
     rz /= np.linalg.norm(rz)
 
-    rot = np.stack([
-        rx,
-        ry,
-        -rz
-    ], axis=0)
-
-
-    aff = np.concatenate([
-        np.eye(3), -origin[:,None]
-    ], axis=1)
+    rot = np.stack([rx, ry, -rz], axis=0)
 
+    aff = np.concatenate([np.eye(3), -origin[:, None]], axis=1)
     ext = np.matmul(rot, aff)
 
-    result = np.concatenate(
-        [ext, np.array([[0,0,0,1]])], axis=0
-    )
-
-
+    result = np.concatenate([ext, np.array([[0, 0, 0, 1]])], axis=0)
     return result
 
-
 def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
-    depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
-    cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths]
-
+    depths = sorted(glob(os.path.join(datapoint_dir, "*depth.png")))
+    cam_ext = [
+        "_".join(re.sub(dataroot.strip("/"), camera_param_dir.strip("/"), f).split("_")[:-1]) + ".xml" for f in depths
+    ]
     for i, (f, pth) in enumerate(zip(cam_ext, depths)):
         if not os.path.exists(f):
             continue
-        params=raw_camparam_from_xml(f)
-        origin, target, up, width, height = params['origin'], params['target'], params['up'],\
-            params['width'], params['height']
+        params = raw_camparam_from_xml(f)
+        origin, target, up, width, height = (
+            params["origin"],
+            params["target"],
+            params["up"],
+            params["width"],
+            params["height"],
+        )
 
         ext_matrix = get_cam_pos(origin, target, up)
 
         #####
-        diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
+        diag = (0.036**2 + 0.024**2) ** 0.5
         focal_length = 0.05
         res = [480, 480]
-        h_relative = (res[1] / res[0])
-        sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
+        h_relative = res[1] / res[0]
+        sensor_width = np.sqrt(diag**2 / (1 + h_relative**2))
         pix_size = sensor_width / res[0]
 
-        K = np.array([
-            [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
-            [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
-            [0, 0, 1]
-        ])
+        K = np.array(
+            [
+                [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
+                [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
+                [0, 0, 1],
+            ]
+        )
 
-        np.savez(pth.split('depth.png')[0]+ 'cam_params.npz', extr=ext_matrix, intr=K)
+        np.savez(pth.split("depth.png")[0] + "cam_params.npz", extr=ext_matrix, intr=K)
 
 
 def main(opt):
@@ -102,21 +98,16 @@ def main(opt):
         if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
             leaf_subdirs.append(dirpath)
 
-
-
     for k, dir_ in enumerate(leaf_subdirs):
-        print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
+        print("Processing dir {}/{}: {}".format(k, len(leaf_subdirs), dir_))
 
         convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
 
-
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     args = argparse.ArgumentParser()
-    args.add_argument('--dataroot', type=str, default='GenReData/')
-    args.add_argument('--mitsuba_xml_root', type=str, default='GenReData/genre-xml_v2')
+    args.add_argument("--dataroot", type=str, default="GenReData/")
+    args.add_argument("--mitsuba_xml_root", type=str, default="GenReData/genre-xml_v2")
 
     opt = args.parse_args()
diff --git a/datasets/partnet.py b/datasets/partnet.py
index f074663..2bf68b4 100644
--- a/datasets/partnet.py
+++ b/datasets/partnet.py
@@ -1,11 +1,13 @@
-from torch.utils.data import Dataset, DataLoader
-import torch
-import numpy as np
-import os
 import json
+import os
 import random
+
+import numpy as np
+import torch
 import trimesh
 from plyfile import PlyData, PlyElement
+from torch.utils.data import Dataset
+
 
 def project_pc_to_image(points, resolution=64):
     """project point clouds into 2D image
@@ -26,29 +28,32 @@ def project_pc_to_image(points, resolution=64):
 
 
 def write_ply(points, filename, text=False):
-    """ input: Nx3, write points to filename as PLY format.
-    """
-    points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
-    vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
-    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
-    with open(filename, mode='wb') as f:
+    """input: Nx3, write points to filename as PLY format."""
+    points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
+    vertex = np.array(points, dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
+    el = PlyElement.describe(vertex, "vertex", comments=["vertices"])
+    with open(filename, mode="wb") as f:
         PlyData([el], text=text).write(f)
 
 
 def rotate_point_cloud(points, transformation_mat):
-
     new_points = np.dot(transformation_mat, points.T).T
 
     return new_points
 
 
 def rotate_point_cloud_by_axis_angle(points, axis, angle_deg):
-    """ align 3depn shapes to shapenet coordinates"""
+    """align 3depn shapes to shapenet coordinates"""
     # angle = math.radians(angle_deg)
    # rot_m = pymesh.Quaternion.fromAxisAngle(axis, angle)
     # rot_m = rot_m.to_matrix()
-    rot_m = np.array([[ 2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
-                      [ 0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
-                      [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16]])
+    rot_m = np.array(
+        [
+            [2.22044605e-16, 0.00000000e+00, 1.00000000e+00],
+            [0.00000000e+00, 1.00000000e+00, 0.00000000e+00],
+            [-1.00000000e+00, 0.00000000e+00, 2.22044605e-16],
+        ]
+    )
 
     new_points = rotate_point_cloud(points, rot_m)
@@ -87,14 +92,13 @@ def sample_point_cloud_by_n(points, n_pts):
 
     return points
 
-
 def collect_data_id(split_dir, classname, phase):
     filename = os.path.join(split_dir, "{}.{}.json".format(classname, phase))
     if not os.path.exists(filename):
         raise ValueError("Invalid filepath: {}".format(filename))
 
     all_ids = []
-    with open(filename, 'r') as fp:
+    with open(filename, "r") as fp:
         info = json.load(fp)
         for item in info:
             all_ids.append(item["anno_id"])
@@ -102,7 +106,6 @@ def collect_data_id(split_dir, classname, phase):
     return all_ids
 
 
-
 class GANdatasetPartNet(Dataset):
     def __init__(self, phase, data_root, category, n_pts):
         super(GANdatasetPartNet, self).__init__()
@@ -114,10 +117,12 @@ class GANdatasetPartNet(Dataset):
 
         self.data_root = data_root
 
-        shape_names = collect_data_id(os.path.join(self.data_root, 'partnet_labels/partnet_train_val_test_split'), category, phase)
+        shape_names = collect_data_id(
+            os.path.join(self.data_root, "partnet_labels/partnet_train_val_test_split"), category, phase
+        )
         self.shape_names = []
         for name in shape_names:
-            path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', name)
+            path = os.path.join(self.data_root, "partnet_labels/partnet_pc_label", name)
             if os.path.exists(path):
                 self.shape_names.append(name)
 
@@ -129,12 +134,12 @@ class GANdatasetPartNet(Dataset):
     @staticmethod
     def load_point_cloud(path):
         pc = trimesh.load(path)
-        pc = pc.vertices / 2.0 # scale to unit sphere
+        pc = pc.vertices / 2.0  # scale to unit sphere
         return pc
 
     @staticmethod
     def read_point_cloud_part_label(path):
-        with open(path, 'r') as fp:
+        with open(path, "r") as fp:
             labels = fp.readlines()
         labels = np.array([int(x) for x in labels])
         return labels
@@ -156,26 +161,31 @@ class GANdatasetPartNet(Dataset):
     def __getitem__(self, index):
         raw_shape_name = self.shape_names[index]
-        raw_ply_path = os.path.join(self.data_root, 'partnet_data', raw_shape_name, 'point_sample/ply-10000.ply')
+        raw_ply_path = os.path.join(self.data_root, "partnet_data", raw_shape_name, "point_sample/ply-10000.ply")
         raw_pc = self.load_point_cloud(raw_ply_path)
-        raw_label_path = os.path.join(self.data_root, 'partnet_labels/partnet_pc_label', raw_shape_name, 'label-merge-level1-10000.txt')
+        raw_label_path = os.path.join(
+            self.data_root, "partnet_labels/partnet_pc_label", raw_shape_name, "label-merge-level1-10000.txt"
+        )
         part_labels = self.read_point_cloud_part_label(raw_label_path)
         raw_pc, n_part_keep = self.random_rm_parts(raw_pc, part_labels)
         raw_pc = sample_point_cloud_by_n(raw_pc, self.raw_n_pts)
         raw_pc = torch.tensor(raw_pc, dtype=torch.float32).transpose(1, 0)
 
         real_shape_name = self.shape_names[index]
-        real_ply_path = os.path.join(self.data_root, 'partnet_data', real_shape_name, 'point_sample/ply-10000.ply')
+        real_ply_path = os.path.join(self.data_root, "partnet_data", real_shape_name, "point_sample/ply-10000.ply")
         real_pc = self.load_point_cloud(real_ply_path)
         real_pc = sample_point_cloud_by_n(real_pc, self.n_pts)
         real_pc = torch.tensor(real_pc, dtype=torch.float32).transpose(1, 0)
 
-        return {"raw": raw_pc, "real": real_pc, "raw_id": raw_shape_name, "real_id": real_shape_name,
-                'n_part_keep': n_part_keep, 'idx': index}
+        return {
+            "raw": raw_pc,
+            "real": real_pc,
+            "raw_id": raw_shape_name,
+            "real_id": real_shape_name,
+            "n_part_keep": n_part_keep,
+            "idx": index,
+        }
 
     def __len__(self):
         return len(self.shape_names)
-
-
-
diff --git a/datasets/shapenet_data_pc.py b/datasets/shapenet_data_pc.py
index 502d614..0e8ff8a 100644
--- a/datasets/shapenet_data_pc.py
+++ b/datasets/shapenet_data_pc.py
@@ -1,33 +1,67 @@
 import os
-import torch
-import numpy as np
-from torch.utils.data import Dataset
-from torch.utils import data
 import random
+
 import numpy as np
-import torch.nn.functional as F
+import torch
+from torch.utils.data import Dataset
 
 # taken from https://github.com/optas/latent_3d_points/blob/8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
 synsetid_to_cate = {
-    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
-    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
-    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
-    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
-    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
-    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
-    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
-    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
-    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
-    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
-    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
-    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
-    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
-    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
-    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
-    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
-    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
-    '04554684': 'washer', '02992529': 'cellphone',
-    '02843684': 'birdhouse', '02871439': 'bookshelf',
+    "02691156": "airplane",
+    "02773838": "bag",
+    "02801938": "basket",
+    "02808440": "bathtub",
+    "02818832": "bed",
+    "02828884": "bench",
+    "02876657": "bottle",
+    "02880940": "bowl",
+    "02924116": "bus",
+    "02933112": "cabinet",
+    "02747177": "can",
+    "02942699": "camera",
+    "02954340": "cap",
+    "02958343": "car",
+    "03001627": "chair",
+    "03046257": "clock",
+    "03207941": "dishwasher",
+    "03211117": "monitor",
+    "04379243": "table",
+    "04401088": "telephone",
+    "02946921": "tin_can",
+    "04460130": "tower",
+    "04468005": "train",
+    "03085013": "keyboard",
+    "03261776": "earphone",
+    "03325088": "faucet",
+    "03337140": "file",
+    "03467517": "guitar",
+    "03513137": "helmet",
+    "03593526": "jar",
+    "03624134": "knife",
+    "03636649": "lamp",
+    "03642806": "laptop",
+    "03691459": "speaker",
+    "03710193": "mailbox",
+    "03759954": "microphone",
+    "03761084": "microwave",
+    "03790512": "motorcycle",
+    "03797390": "mug",
+    "03928116": "piano",
+    "03938244": "pillow",
+    "03948459": "pistol",
+    "03991062": "pot",
+    "04004475": "printer",
+    "04074963": "remote_control",
+    "04090263": "rifle",
+    "04099429": "rocket",
+    "04225987": "skateboard",
+    "04256520": "sofa",
+    "04330267": "stove",
+    "04530566": "vessel",
+    "04554684": "washer",
+    "02992529": "cellphone",
+    "02843684": "birdhouse",
+    "02871439": "bookshelf",
     # '02858304': 'boat', no boat in our dataset, merged into vessels
     # '02834778': 'bicycle', not in our taxonomy
 }
@@ -35,13 +69,23 @@ synsetid_to_cate = {
 cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
 
 class Uniform15KPC(Dataset):
-    def __init__(self, root_dir, subdirs, tr_sample_size=10000,
-                 te_sample_size=10000, split='train', scale=1.,
-                 normalize_per_shape=False, box_per_shape=False,
-                 random_subsample=False,
-                 normalize_std_per_axis=False,
-                 all_points_mean=None, all_points_std=None,
-                 input_dim=3, use_mask=False):
+    def __init__(
+        self,
+        root_dir,
+        subdirs,
+        tr_sample_size=10000,
+        te_sample_size=10000,
+        split="train",
+        scale=1.0,
+        normalize_per_shape=False,
+        box_per_shape=False,
+        random_subsample=False,
+        normalize_std_per_axis=False,
+        all_points_mean=None,
+        all_points_std=None,
+        input_dim=3,
+        use_mask=False,
+    ):
         self.root_dir = root_dir
         self.split = split
         self.in_tr_sample_size = tr_sample_size
@@ -67,9 +111,9 @@ class Uniform15KPC(Dataset):
 
             all_mids = []
             for x in os.listdir(sub_path):
-                if not x.endswith('.npy'):
+                if not x.endswith(".npy"):
                     continue
-                all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
+                all_mids.append(os.path.join(self.split, x[: -len(".npy")]))
 
             # NOTE: [mid] contains the split: i.e. "train/" or "val/" or "test/"
             for mid in all_mids:
@@ -111,7 +155,9 @@ class Uniform15KPC(Dataset):
             B, N = self.all_points.shape[:2]
             self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
 
-            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
+            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(
+                axis=1
+            ).reshape(B, 1, input_dim)
 
         else:  # normalize across the dataset
             self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
@@ -129,8 +175,7 @@ class Uniform15KPC(Dataset):
         self.tr_sample_size = min(10000, tr_sample_size)
         self.te_sample_size = min(5000, te_sample_size)
         print("Total number of data:%d" % len(self.train_points))
-        print("Min number of points: (train)%d (test)%d"
-              % (self.tr_sample_size, self.te_sample_size))
+        print("Min number of points: (train)%d (test)%d" % (self.tr_sample_size, self.te_sample_size))
         assert self.scale == 1, "Scale (!= 1) is deprecated"
 
    def get_pc_stats(self, idx):
@@ -139,7 +184,6 @@ class Uniform15KPC(Dataset):
             s = self.all_points_std[idx].reshape(1, -1)
             return m, s
 
-
         return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)
 
     def renormalize(self, mean, std):
@@ -173,11 +217,14 @@ class Uniform15KPC(Dataset):
         sid, mid = self.all_cate_mids[idx]
 
         out = {
-            'idx': idx,
-            'train_points': tr_out,
-            'test_points': te_out,
-            'mean': m, 'std': s, 'cate_idx': cate_idx,
-            'sid': sid, 'mid': mid
+            "idx": idx,
+            "train_points": tr_out,
+            "test_points": te_out,
+            "mean": m,
+            "std": s,
+            "cate_idx": cate_idx,
+            "sid": sid,
+            "mid": mid,
         }
 
         if self.use_mask:
@@ -192,26 +239,35 @@ class Uniform15KPC(Dataset):
             # out['train_points_masked'] = masked
             # out['train_masks'] = tr_mask
             tr_mask = self.mask_transform(tr_out)
-            out['train_masks'] = tr_mask
+            out["train_masks"] = tr_mask
 
         return out
 
 
 class ShapeNet15kPointClouds(Uniform15KPC):
-    def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
-                 categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
-                 split='train', scale=1., normalize_per_shape=False,
-                 normalize_std_per_axis=False, box_per_shape=False,
-                 random_subsample=False,
-                 all_points_mean=None, all_points_std=None,
-                 use_mask=False):
+    def __init__(
+        self,
+        root_dir="data/ShapeNetCore.v2.PC15k",
+        categories=["airplane"],
+        tr_sample_size=10000,
+        te_sample_size=2048,
+        split="train",
+        scale=1.0,
+        normalize_per_shape=False,
+        normalize_std_per_axis=False,
+        box_per_shape=False,
+        random_subsample=False,
+        all_points_mean=None,
+        all_points_std=None,
+        use_mask=False,
+    ):
         self.root_dir = root_dir
         self.split = split
-        assert self.split in ['train', 'test', 'val']
+        assert self.split in ["train", "test", "val"]
         self.tr_sample_size = tr_sample_size
         self.te_sample_size = te_sample_size
         self.cates = categories
-        if 'all' in categories:
+        if "all" in categories:
             self.synset_ids = list(cate_to_synsetid.values())
         else:
             self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
@@ -221,19 +277,21 @@ class ShapeNet15kPointClouds(Uniform15KPC):
         self.display_axis_order = [0, 2, 1]
 
         super(ShapeNet15kPointClouds, self).__init__(
-            root_dir, self.synset_ids,
+            root_dir,
+            self.synset_ids,
             tr_sample_size=tr_sample_size,
             te_sample_size=te_sample_size,
-            split=split, scale=scale,
-            normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
+            split=split,
+            scale=scale,
+            normalize_per_shape=normalize_per_shape,
+            box_per_shape=box_per_shape,
             normalize_std_per_axis=normalize_std_per_axis,
             random_subsample=random_subsample,
-            all_points_mean=all_points_mean, all_points_std=all_points_std,
-            input_dim=3, use_mask=use_mask)
-
-
+            all_points_mean=all_points_mean,
+            all_points_std=all_points_std,
+            input_dim=3,
+            use_mask=use_mask,
+        )
 
 
 ####################################################################################
-
-
diff --git a/datasets/shapenet_data_sv.py b/datasets/shapenet_data_sv.py
index 63d5a4c..67d0641 100644
--- a/datasets/shapenet_data_sv.py
+++ b/datasets/shapenet_data_sv.py
@@ -1,34 +1,70 @@
+import hashlib
+import os
 import warnings
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
 from torch.utils.data import Dataset
 from tqdm import tqdm
-from pathlib import Path
-import os
-import numpy as np
-
-import hashlib
-import torch
-import matplotlib.pyplot as plt
 
 synset_to_label = {
-    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
-    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
-    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
-    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
-    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
-    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
-    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
-    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
-    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
-    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
-    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
-    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
-    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
-    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
-    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
-    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
-    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
-    '04554684': 'washer', '02992529': 'cellphone',
-    '02843684': 'birdhouse', '02871439': 'bookshelf',
+    "02691156": "airplane",
+    "02773838": "bag",
+    "02801938": "basket",
+    "02808440": "bathtub",
+    "02818832": "bed",
+    "02828884": "bench",
+    "02876657": "bottle",
+    "02880940": "bowl",
+    "02924116": "bus",
+    "02933112": "cabinet",
+    "02747177": "can",
+    "02942699": "camera",
+    "02954340": "cap",
+    "02958343": "car",
+    "03001627": "chair",
+    "03046257": "clock",
+    "03207941": "dishwasher",
+    "03211117": "monitor",
+    "04379243": "table",
+    "04401088": "telephone",
+    "02946921": "tin_can",
+    "04460130": "tower",
+    "04468005": "train",
+    "03085013": "keyboard",
+    "03261776": "earphone",
+    "03325088": "faucet",
+    "03337140": "file",
+    "03467517": "guitar",
+    "03513137": "helmet",
+    "03593526": "jar",
+    "03624134": "knife",
+    "03636649": "lamp",
+    "03642806": "laptop",
+    "03691459": "speaker",
+    "03710193": "mailbox",
+    "03759954": "microphone",
+    "03761084": "microwave",
+    "03790512": "motorcycle",
+    "03797390": "mug",
+    "03928116": "piano",
+    "03938244": "pillow",
+    "03948459": "pistol",
+    "03991062": "pot",
+    "04004475": "printer",
+    "04074963": "remote_control",
+    "04090263": "rifle",
+    "04099429": "rocket",
+    "04225987": "skateboard",
+    "04256520": "sofa",
+    "04330267": "stove",
+    "04530566": "vessel",
+    "04554684": "washer",
+    "02992529": "cellphone",
+    "02843684": "birdhouse",
+    "02871439": "bookshelf",
     # '02858304': 'boat', no boat in our dataset, merged into vessels
     # '02834778': 'bicycle', not in our taxonomy
 }
@@ -36,30 +72,44 @@ synset_to_label = {
 # Label to Synset mapping (for ShapeNet core classes)
 label_to_synset = {v: k for k, v in synset_to_label.items()}
 
+
 def _convert_categories(categories):
-    assert categories is not None, 'List of categories cannot be empty!'
-    if not (c in synset_to_label.keys() + label_to_synset.keys()
-            for c in categories):
-        warnings.warn('Some or all of the categories requested are not part of \
-            ShapeNetCore. Data loading may fail if these categories are not avaliable.')
-    synsets = [label_to_synset[c] if c in label_to_synset.keys()
-               else c for c in categories]
+    assert categories is not None, "List of categories cannot be empty!"
+    if not (c in synset_to_label.keys() + label_to_synset.keys() for c in categories):
+        warnings.warn(
+            "Some or all of the categories requested are not part of \
+            ShapeNetCore. Data loading may fail if these categories are not available."
+        )
+    synsets = [label_to_synset[c] if c in label_to_synset.keys() else c for c in categories]
     return synsets
 
 
 class ShapeNet_Multiview_Points(Dataset):
-    def __init__(self, root_pc:str, root_views: str, cache: str, categories: list = ['chair'], split: str= 'val',
-                 npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
+    def __init__(
+        self,
+        root_pc: str,
+        root_views: str,
+        cache: str,
+        categories: list = ["chair"],
+        split: str = "val",
+        npoints=2048,
+        sv_samples=800,
+        all_points_mean=None,
+        all_points_std=None,
+        get_image=False,
+    ):
         self.root = Path(root_views)
         self.split = split
         self.get_image = get_image
         params = {
-            'cat': categories,
-            'npoints': npoints,
-            'sv_samples': sv_samples,
+            "cat": categories,
+            "npoints": npoints,
+            "sv_samples": sv_samples,
         }
         params = tuple(sorted(pair for pair in params.items()))
-        self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format('_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())
+        self.cache_dir = Path(cache) / "svpoints/{}/{}".format(
+            "_".join(categories), hashlib.md5(bytes(repr(params), "utf-8")).hexdigest()
+        )
 
         self.cache_dir.mkdir(parents=True, exist_ok=True)
         self.paths = []
@@ -74,13 +124,12 @@ class ShapeNet_Multiview_Points(Dataset):
 
         # loops through desired classes
         for i in range(len(self.synsets)):
-
             syn = self.synsets[i]
             class_target = self.root / syn
             if not class_target.exists():
-                raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
-                    syn, self.labels[i], str(class_target)))
-
+                raise ValueError(
+                    "Class {0} ({1}) was not found at location {2}.".format(syn, self.labels[i], str(class_target))
+                )
 
             sub_path_pc = os.path.join(root_pc, syn, split)
             if not os.path.isdir(sub_path_pc):
@@ -90,30 +139,30 @@ class ShapeNet_Multiview_Points(Dataset):
             self.all_mids = []
             self.imgs = []
             for x in os.listdir(sub_path_pc):
-                if not x.endswith('.npy'):
+                if not x.endswith(".npy"):
                     continue
-                self.all_mids.append(os.path.join(split, x[:-len('.npy')]))
+                self.all_mids.append(os.path.join(split, x[: -len(".npy")]))
 
             for mid in tqdm(self.all_mids):
                 # obj_fname = os.path.join(sub_path, x)
                 obj_fname = os.path.join(root_pc, syn, mid + ".npy")
-                cams_pths = list((self.root/ syn/ mid.split('/')[-1]).glob('*_cam_params.npz'))
+                cams_pths = list((self.root / syn / mid.split("/")[-1]).glob("*_cam_params.npz"))
                 if len(cams_pths) < 20:
                     continue
                 point_cloud = np.load(obj_fname)
                 sv_points_group = []
                 img_path_group = []
-                (self.cache_dir / (mid.split('/')[-1])).mkdir(parents=True, exist_ok=True)
+                (self.cache_dir / (mid.split("/")[-1])).mkdir(parents=True, exist_ok=True)
                 success = True
                 for i, cp in enumerate(cams_pths):
                     cp = str(cp)
-                    vp = cp.split('cam_params')[0] + 'depth.png'
-                    depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
-                    cache_pth = str(self.cache_dir / mid.split('/')[-1] / os.path.basename(depth_minmax_pth) )
+                    vp = cp.split("cam_params")[0] + "depth.png"
+                    depth_minmax_pth = cp.split("_cam_params")[0] + ".npy"
+                    cache_pth = str(self.cache_dir / mid.split("/")[-1] / os.path.basename(depth_minmax_pth))
 
                     cam_params = np.load(cp)
-                    extr = cam_params['extr']
-                    intr = cam_params['intr']
+                    extr = cam_params["extr"]
+                    intr = cam_params["intr"]
 
                     self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)
 
@@ -125,7 +174,7 @@ class ShapeNet_Multiview_Points(Dataset):
                         sv_points_group.append(sv_point_cloud)
                     except Exception as e:
                         print(e)
-                        success=False
+                        success = False
                         break
                 if not success:
                     continue
@@ -144,64 +193,66 @@ class ShapeNet_Multiview_Points(Dataset):
         self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)
 
         self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
-        self.train_points = self.all_points[:,:10000]
-        self.test_points = self.all_points[:,10000:]
+        self.train_points = self.all_points[:, :10000]
+        self.test_points = self.all_points[:, 10000:]
 
         self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std
 
     def get_pc_stats(self, idx):
-
-        return self.all_points_mean.reshape(1,1, -1), self.all_points_std.reshape(1,1, -1)
+        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)
 
     def __len__(self):
-        """Returns the length of the dataset. """
+        """Returns the length of the dataset."""
        return len(self.all_points)
 
     def __getitem__(self, index):
-
-
         tr_out = self.train_points[index]
         tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
         tr_out = tr_out[tr_idxs, :]
 
-        gt_points = self.test_points[index][:self.npoints]
+        gt_points = self.test_points[index][: self.npoints]
 
         m, s = self.get_pc_stats(index)
 
         sv_points = self.all_points_sv[index]
-        idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]#np.random.choice(sv_points.shape[0], 500, replace=False)
+        idxs = np.arange(0, sv_points.shape[-2])[
+            : self.sv_samples
+        ]  # np.random.choice(sv_points.shape[0], 500, replace=False)
 
-        data = torch.cat([torch.from_numpy(sv_points[:,idxs]).float(),
-                          torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
+        data = torch.cat(
+            [
+                torch.from_numpy(sv_points[:, idxs]).float(),
+                torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2]),
+            ],
+            dim=1,
+        )
 
         masks = torch.zeros_like(data)
-        masks[:,:idxs.shape[0]] = 1
+        masks[:, : idxs.shape[0]] = 1
 
-        res = {'train_points': torch.from_numpy(tr_out).float(),
-               'test_points': torch.from_numpy(gt_points).float(),
-               'sv_points': data,
-               'masks': masks,
-               'std': s, 'mean': m,
-               'idx': index,
-               'name':self.all_mids[index]
-               }
-
-        if self.split != 'train' and self.get_image:
+        res = {
+            "train_points": torch.from_numpy(tr_out).float(),
+            "test_points": torch.from_numpy(gt_points).float(),
+            "sv_points": data,
+            "masks": masks,
+            "std": s,
+            "mean": m,
+            "idx": index,
+            "name": self.all_mids[index],
+        }
+        if self.split != "train" and self.get_image:
             img_lst = []
             for n in range(self.all_points_sv.shape[1]):
-
-                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2,0,1)[:3]
+                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
                 img_lst.append(img)
             img = torch.stack(img_lst, dim=0)
-            res['image'] = img
+            res["image"] = img
 
         return res
 
-
     def _render(self, cache_path, depth_pth, depth_minmax_pth):
         # if not os.path.exists(cache_path.split('.npy')[0] + '_color.png') and os.path.exists(cache_path):
         #
@@ -210,11 +261,9 @@ class ShapeNet_Multiview_Points(Dataset):
         if os.path.exists(cache_path):
             data = np.load(cache_path)
         else:
-
             data, depth = self.transform(depth_pth, depth_minmax_pth)
-            assert data.shape[0] > 600, 'Only {} points found'.format(data.shape[0])
+            assert data.shape[0] > 600, "Only {} points found".format(data.shape[0])
             data = data[np.random.choice(data.shape[0], 600, replace=False)]
             np.save(cache_path, data)
 
         return data
-
diff --git a/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py b/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py
index f92e6f1..18d4667 100644
--- a/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py
+++ b/metrics/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py
@@ -1,25 +1,32 @@
-from torch import nn
-from torch.autograd import Function
-import torch
 import importlib
 import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
 chamfer_found = importlib.find_loader("chamfer_2D") is not None
 if not chamfer_found:
     ## Cool trick from https://github.com/chrdiller
     print("Jitting Chamfer 2D")
     from torch.utils.cpp_extension import load
-    chamfer_2D = load(name="chamfer_2D",
-          sources=[
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
-              ])
+
+    chamfer_2D = load(
+        name="chamfer_2D",
+        sources=[
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer2D.cu"]),
+        ],
+    )
     print("Loaded JIT 2D CUDA chamfer distance")
 
 else:
     import chamfer_2D
+
     print("Loaded compiled 2D CUDA chamfer distance")
 
+
 # Chamfer's distance module @thibaultgroueix
 # GPU tensors only
 class chamfer_2DFunction(Function):
@@ -57,9 +64,7 @@ class chamfer_2DFunction(Function):
         gradxyz1 = gradxyz1.to(device)
         gradxyz2 = gradxyz2.to(device)
-        chamfer_2D.backward(
-            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
-        )
+        chamfer_2D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
         return gradxyz1, gradxyz2
diff --git a/metrics/ChamferDistancePytorch/chamfer2D/setup.py b/metrics/ChamferDistancePytorch/chamfer2D/setup.py
index 1b729b3..6ead863 100644
--- a/metrics/ChamferDistancePytorch/chamfer2D/setup.py
+++ b/metrics/ChamferDistancePytorch/chamfer2D/setup.py
@@ -2,15 +2,16 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 setup(
-    name='chamfer_2D',
+    name="chamfer_2D",
     ext_modules=[
-        CUDAExtension('chamfer_2D', [
-            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
-            "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
-        ]),
+        CUDAExtension(
+            "chamfer_2D",
+            [
+                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
+                "/".join(__file__.split("/")[:-1] + ["chamfer2D.cu"]),
+            ],
+        ),
     ],
-
-    extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
\ No newline at end of file
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py b/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py
index 5811e46..f48e71c 100644
--- a/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py
+++ b/metrics/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py
@@ -1,25 +1,30 @@
-from torch import nn
-from torch.autograd import Function
-import torch
 import importlib
 import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
 chamfer_found = importlib.find_loader("chamfer_3D") is not None
 if not chamfer_found:
     ## Cool trick from https://github.com/chrdiller
     print("Jitting Chamfer 3D")
     from torch.utils.cpp_extension import load
-    chamfer_3D = load(name="chamfer_3D",
-          sources=[
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
-              ],
-          extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],)
+
+    chamfer_3D = load(
+        name="chamfer_3D",
+        sources=[
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer3D.cu"]),
+        ],
+        extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    )
     print("Loaded JIT 3D CUDA chamfer distance")
 
 else:
     import chamfer_3D
+
     print("Loaded compiled 3D CUDA chamfer distance")
@@ -60,9 +65,7 @@ class chamfer_3DFunction(Function):
         gradxyz1 = gradxyz1.to(device)
         gradxyz2 = gradxyz2.to(device)
-        chamfer_3D.backward(
-            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
-        )
+        chamfer_3D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
         return gradxyz1, gradxyz2
@@ -74,4 +77,3 @@ class chamfer_3DDist(nn.Module):
         input1 = input1.contiguous()
         input2 = input2.contiguous()
         return chamfer_3DFunction.apply(input1, input2)
-
diff --git a/metrics/ChamferDistancePytorch/chamfer3D/setup.py b/metrics/ChamferDistancePytorch/chamfer3D/setup.py
index 7d3723e..bdf061e 100644
--- a/metrics/ChamferDistancePytorch/chamfer3D/setup.py
+++ b/metrics/ChamferDistancePytorch/chamfer3D/setup.py
@@ -2,15 +2,16 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 setup(
-    name='chamfer_3D',
+    name="chamfer_3D",
     ext_modules=[
-        CUDAExtension('chamfer_3D', [
-            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
-            "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
-        ]),
+        CUDAExtension(
+            "chamfer_3D",
+            [
+                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
+                "/".join(__file__.split("/")[:-1] + ["chamfer3D.cu"]),
+            ],
+        ),
     ],
-
-    extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
\ No newline at end of file
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py b/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py
index 3730a1f..042372c 100644
--- a/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py
+++ b/metrics/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py
@@ -1,24 +1,29 @@
-from torch import nn
-from torch.autograd import Function
-import torch
 import importlib
 import os
 
+import torch
+from torch import nn
+from torch.autograd import Function
+
 chamfer_found = importlib.find_loader("chamfer_5D") is not None
 if not chamfer_found:
     ## Cool trick from https://github.com/chrdiller
     print("Jitting Chamfer 5D")
     from torch.utils.cpp_extension import load
-    chamfer_5D = load(name="chamfer_5D",
-          sources=[
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
-              "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
-              ])
+
+    chamfer_5D = load(
+        name="chamfer_5D",
+        sources=[
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer_cuda.cpp"]),
+            "/".join(os.path.abspath(__file__).split("/")[:-1] + ["chamfer5D.cu"]),
+        ],
+    )
     print("Loaded JIT 5D CUDA chamfer distance")
 
 else:
     import chamfer_5D
+
     print("Loaded compiled 5D CUDA chamfer distance")
@@ -59,9 +64,7 @@ class chamfer_5DFunction(Function):
         gradxyz1 = gradxyz1.to(device)
         gradxyz2 = gradxyz2.to(device)
-        chamfer_5D.backward(
-            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
-        )
+        chamfer_5D.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
         return gradxyz1, gradxyz2
diff --git a/metrics/ChamferDistancePytorch/chamfer5D/setup.py b/metrics/ChamferDistancePytorch/chamfer5D/setup.py
index 297aa33..77a8694 100644
--- a/metrics/ChamferDistancePytorch/chamfer5D/setup.py
+++ b/metrics/ChamferDistancePytorch/chamfer5D/setup.py
@@ -2,15 +2,16 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 setup(
-    name='chamfer_5D',
+    name="chamfer_5D",
     ext_modules=[
-        CUDAExtension('chamfer_5D', [
-            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
-            "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
-        ]),
+        CUDAExtension(
+            "chamfer_5D",
+            [
+                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
+                "/".join(__file__.split("/")[:-1] + ["chamfer5D.cu"]),
+            ],
+        ),
     ],
-
-    extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
\ No newline at end of file
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/metrics/ChamferDistancePytorch/chamfer_python.py b/metrics/ChamferDistancePytorch/chamfer_python.py
index ce0aeaa..a6bfc18 100644
--- a/metrics/ChamferDistancePytorch/chamfer_python.py
+++ b/metrics/ChamferDistancePytorch/chamfer_python.py
@@ -33,8 +33,7 @@ def distChamfer(a, b):
     xx = torch.pow(x, 2).sum(2)
     yy = torch.pow(y, 2).sum(2)
     zz = torch.bmm(x, y.transpose(2, 1))
-    rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
-    ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
+    rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x)  # Diagonal elements xx
+    ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y)  # Diagonal elements yy
     P = rx.transpose(2, 1) + ry - 2 * zz
     return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
-
diff --git a/metrics/ChamferDistancePytorch/fscore.py b/metrics/ChamferDistancePytorch/fscore.py
index 265378b..3718187 100644
--- a/metrics/ChamferDistancePytorch/fscore.py
+++ b/metrics/ChamferDistancePytorch/fscore.py
@@ -1,5 +1,6 @@
 import torch
 
+
 def fscore(dist1, dist2, threshold=0.001):
     """
     Calculates the F-score between two point clouds with the corresponding threshold value.
@@ -14,4 +15,3 @@ def fscore(dist1, dist2, threshold=0.001):
     fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
     fscore[torch.isnan(fscore)] = 0
     return fscore, precision_1, precision_2
-
diff --git a/metrics/ChamferDistancePytorch/unit_test.py b/metrics/ChamferDistancePytorch/unit_test.py
index 13af6a3..d3b2eb8 100644
--- a/metrics/ChamferDistancePytorch/unit_test.py
+++ b/metrics/ChamferDistancePytorch/unit_test.py
@@ -1,20 +1,23 @@
-import torch, time
+import time
+
 import chamfer2D.dist_chamfer_2D
 import chamfer3D.dist_chamfer_3D
 import chamfer5D.dist_chamfer_5D
 import chamfer_python
+import torch
 
 cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
 cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
 cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
 
-from torch.autograd import Variable
 from fscore import fscore
+from torch.autograd import Variable
+
 
 def test_chamfer(distChamfer, dim):
     points1 = torch.rand(4, 100, dim).cuda()
     points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
-    dist1, dist2, idx1, idx2= distChamfer(points1, points2)
+    dist1, dist2, idx1, idx2 = distChamfer(points1, points2)
 
     loss = torch.sum(dist1)
     loss.backward()
@@ -29,9 +32,9 @@ def test_chamfer(distChamfer, dim):
     xd1 = idx1 - myidx1
     xd2 = idx2 - myidx2
     assert (
-            torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
+        torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
     ), "chamfer cuda and chamfer normal are not giving the same results"
-    print(f"fscore :", fscore(dist1, dist2))
+    print("fscore :", fscore(dist1, dist2))
     print("Unit test passed")
 
 
@@ -49,7 +52,6 @@ def timings(distChamfer, dim):
         loss.backward()
     print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
 
-
     print("Timings : Start Pythonic version")
     start = time.time()
     for i in range(num_it):
@@ -61,9 +63,8 @@ def timings(distChamfer, dim):
     print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
 
 
-
-dims = [2,3,5]
-for i,cham in enumerate([cham2D, cham3D, cham5D]):
+dims = [2, 3, 5]
+for i, cham in enumerate([cham2D, cham3D, cham5D]):
     print(f"testing Chamfer {dims[i]}D")
     test_chamfer(cham, dims[i])
     timings(cham, dims[i])
diff --git a/metrics/PyTorchEMD/emd.py b/metrics/PyTorchEMD/emd.py
index b0a01ce..0f6460a 100644
--- a/metrics/PyTorchEMD/emd.py
+++ b/metrics/PyTorchEMD/emd.py
@@ -1,5 +1,5 @@
-import torch
 import emd_cuda
+import torch
 
 
 class EarthMoverDistanceFunction(torch.autograd.Function):
@@ -44,4 +44,3 @@ def earth_mover_distance(xyz1, xyz2, transpose=True):
     cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
     cost = cost / xyz1.shape[1]
     return cost
-
diff --git a/metrics/PyTorchEMD/setup.py b/metrics/PyTorchEMD/setup.py
index f648c3e..0c9be58 100644
--- a/metrics/PyTorchEMD/setup.py
+++ b/metrics/PyTorchEMD/setup.py
@@ -9,19 +9,17 @@ Notes:
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
-
 setup(
-    name='emd_ext',
+    name="emd_ext",
     ext_modules=[
         CUDAExtension(
-            name='emd_cuda',
+            name="emd_cuda",
             sources=[
-                'cuda/emd.cpp',
-                'cuda/emd_kernel.cu',
+                "cuda/emd.cpp",
+                "cuda/emd_kernel.cu",
             ],
-            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
+            extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
         ),
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    })
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/metrics/PyTorchEMD/test_emd_loss.py b/metrics/PyTorchEMD/test_emd_loss.py
index 66aa33c..417ecc1 100644
--- a/metrics/PyTorchEMD/test_emd_loss.py
+++ b/metrics/PyTorchEMD/test_emd_loss.py
@@ -1,6 +1,5 @@
-import torch
 import numpy as np
-import time
+import torch
 from emd import earth_mover_distance
 
 # gt
@@ -13,10 +12,12 @@ print(p2)
 p1.requires_grad = True
 p2.requires_grad = True
 
-gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
-          (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
-          (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
-print('gt_dist: ', gt_dist)
+gt_dist = (
+    (((p1[0, 0] - p2[0, 1]) ** 2).sum() + ((p1[0, 1] - p2[0, 0]) ** 2).sum()) / 2
+    + (((p1[1, 0] - p2[1, 1]) ** 2).sum() + ((p1[1, 1] - p2[1, 0]) ** 2).sum()) * 2
+    + (((p1[2, 0] - p2[2, 1]) ** 2).sum() + ((p1[2, 1] - p2[2, 0]) ** 2).sum()) / 3
+)
+print("gt_dist: ", gt_dist)
 gt_dist.backward()
 print(p1.grad)
 
@@ -41,4 +42,3 @@ print(loss)
 loss.backward()
 print(p1.grad)
 print(p2.grad)
-
diff --git a/metrics/evaluation_metrics.py b/metrics/evaluation_metrics.py
index c0cbf62..0f250e9 100644
--- a/metrics/evaluation_metrics.py
+++ b/metrics/evaluation_metrics.py
@@ -1,17 +1,19 @@
-import torch
-import numpy as np
 import warnings
+
+import numpy as np
+import torch
+from numpy.linalg import norm
 from scipy.stats import entropy
 from sklearn.neighbors import NearestNeighbors
-from numpy.linalg import norm
-
-from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
-from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
-from metrics.ChamferDistancePytorch.fscore import fscore
 from tqdm import tqdm
 
+from metrics.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist
+from metrics.ChamferDistancePytorch.fscore import fscore
+from metrics.PyTorchEMD.emd import earth_mover_distance as EMD
+
 cham3D = chamfer_3DDist()
 
+
 # Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
 def distChamfer(a, b):
     x, y = a, b
@@ -22,11 +24,11 @@ def distChamfer(a, b):
     diag_ind = torch.arange(0, num_points).to(a).long()
     rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
     ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
-    P = (rx.transpose(2, 1) + ry - 2 * zz)
+    P = rx.transpose(2, 1) + ry - 2 * zz
     return P.min(1)[0], P.min(2)[0]
 
 
-def EMD_CD(sample_pcs, ref_pcs, batch_size,  reduced=True):
+def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
     N_sample = sample_pcs.shape[0]
     N_ref = ref_pcs.shape[0]
     assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
@@ -56,13 +58,10 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
     cd = torch.cat(cd_lst)
     emd = torch.cat(emd_lst)
     fs_lst = torch.cat(fs_lst).mean()
-    results = {
-        'MMD-CD': cd,
-        'MMD-EMD': emd,
-        'fscore': fs_lst
-    }
+    results = {"MMD-CD": cd, "MMD-EMD": emd, "fscore": fs_lst}
     return results
 
+
 def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True):
     N_sample = sample_pcs.shape[0]
     N_ref = ref_pcs.shape[0]
@@ -107,7 +106,7 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
     M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
     if sqrt:
         M = M.abs().sqrt()
-    INFINITY = float('inf')
+    INFINITY = float("inf")
     val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)
 
     count = torch.zeros(n0 + n1).to(Mxx)
@@ -116,19 +115,21 @@ def knn(Mxx, Mxy, Myy, k, sqrt=False):
     pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
 
     s = {
-        'tp': (pred * label).sum(),
-        'fp': (pred * (1 - label)).sum(),
-        'fn': ((1 - pred) * label).sum(),
-        'tn': ((1 - pred) * (1 - label)).sum(),
+        "tp": (pred * label).sum(),
+        "fp": (pred * (1 - label)).sum(),
+        "fn": ((1 - pred) * label).sum(),
+        "tn": ((1 - pred) * (1 - label)).sum(),
     }
 
-    s.update({
-        'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
-        'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
-        'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
-        'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
-        'acc': torch.eq(label, pred).float().mean(),
-    })
+    s.update(
+        {
+            "precision": s["tp"] / (s["tp"] + s["fp"] + 1e-10),
+            "recall": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
+            "acc_t": s["tp"] / (s["tp"] + s["fn"] + 1e-10),
+            "acc_f": s["tn"] / (s["tn"] + s["fp"] + 1e-10),
+            "acc": torch.eq(label, pred).float().mean(),
+        }
+    )
     return s
@@ -141,9 +142,9 @@ def lgan_mmd_cov(all_dist):
     cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
     cov = torch.tensor(cov).to(all_dist)
     return {
-        'lgan_mmd': mmd,
-        'lgan_cov': cov,
-        'lgan_mmd_smp': mmd_smp,
+        "lgan_mmd": mmd,
+        "lgan_cov": cov,
+        "lgan_mmd_smp": mmd_smp,
     }
@@ -153,27 +154,19 @@ def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
     M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
 
     res_cd = lgan_mmd_cov(M_rs_cd.t())
-    results.update({
-        "%s-CD" % k: v for k, v in res_cd.items()
-    })
+    results.update({"%s-CD" % k: v for k, v in res_cd.items()})
 
     res_emd = lgan_mmd_cov(M_rs_emd.t())
-    results.update({
-        "%s-EMD" % k: v for k, v in res_emd.items()
-    })
+    results.update({"%s-EMD" % k: v for k, v in res_emd.items()})
 
     M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
     M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
 
     # 1-NN results
     one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
-    results.update({
-        "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
-    })
+    results.update({"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if "acc" in k})
     one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
-    results.update({
-        "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
-    })
+    results.update({"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if "acc" in k})
 
     return results
@@ -227,11 +220,11 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
     bound = 0.5 + epsilon
     if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
         if verbose:
-            warnings.warn('Point-clouds are not in unit cube.')
+            warnings.warn("Point-clouds are not in unit cube.")
 
-    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
+    if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
         if verbose:
-            warnings.warn('Point-clouds are not in unit sphere.')
+            warnings.warn("Point-clouds are not in unit sphere.")
 
     grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
     grid_coordinates = grid_coordinates.reshape(-1, 3)
@@ -260,9 +253,9 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose
 
 def jensen_shannon_divergence(P, Q):
     if np.any(P < 0) or np.any(Q < 0):
-        raise ValueError('Negative values.')
+        raise ValueError("Negative values.")
     if len(P) != len(Q):
-        raise ValueError('Non equal size.')
+        raise ValueError("Non equal size.")
 
     P_ = P / np.sum(P)  # Ensure probabilities.
Q_ = Q / np.sum(Q) @@ -275,7 +268,7 @@ def jensen_shannon_divergence(P, Q): res2 = _jsdiv(P_, Q_) if not np.allclose(res, res2, atol=10e-5, rtol=0): - warnings.warn('Numerical values of two JSD methods don\'t agree.') + warnings.warn("Numerical values of two JSD methods don't agree.") return res @@ -312,11 +305,9 @@ if __name__ == "__main__": r_dist = min_r.mean().cpu().detach().item() print(l_dist, r_dist) - emd_batch = EMD(x.cuda(), y.cuda(), False) print(emd_batch.shape) print(emd_batch.mean().detach().item()) jsd = jsd_between_point_cloud_sets(x.numpy(), y.numpy()) print(jsd) - diff --git a/model/pvcnn_completion.py b/model/pvcnn_completion.py index db48b6b..7603f86 100644 --- a/model/pvcnn_completion.py +++ b/model/pvcnn_completion.py @@ -1,13 +1,14 @@ import functools -import torch.nn as nn -import torch import numpy as np -from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish +import torch +import torch.nn as nn + +from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish def _linear_gn_relu(in_channels, out_channels): - return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) + return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish()) def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): @@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) -def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, - width_multiplier=1, voxel_resolution_multiplier=1): +def create_pointnet_components( + blocks, + in_channels, + embed_dim, + with_se=False, + normalize=True, + eps=0, + width_multiplier=1, + voxel_resolution_multiplier=1, +): r, vr = width_multiplier, voxel_resolution_multiplier layers, concat_channels = [], 0 @@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no if voxel_resolution is None: block = SharedMLP else: - block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, - with_se=with_se, normalize=normalize, eps=eps) + block = functools.partial( + PVConv, + kernel_size=3, + resolution=int(vr * voxel_resolution), + attention=attention, + with_se=with_se, + normalize=normalize, + eps=eps, + ) if c == 0: layers.append(block(in_channels, out_channels)) else: - layers.append(block(in_channels+embed_dim, out_channels)) + layers.append(block(in_channels + embed_dim, out_channels)) in_channels = out_channels concat_channels += out_channels c += 1 return layers, in_channels, concat_channels -def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False, - dropout=0.1, with_se=False, normalize=True, eps=0, - width_multiplier=1, voxel_resolution_multiplier=1): +def create_pointnet2_sa_components( + sa_blocks, + extra_feature_channels, + embed_dim=64, + use_att=False, + dropout=0.1, + with_se=False, + normalize=True, + eps=0, + width_multiplier=1, + voxel_resolution_multiplier=1, +): r, vr = width_multiplier, voxel_resolution_multiplier in_channels = extra_feature_channels + 3 @@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim= out_channels, num_blocks, voxel_resolution = conv_configs out_channels = int(r * out_channels) 
             for p in range(num_blocks):
-                attention = (c+1) % 2 == 0 and c > 0 and use_att and p == 0
+                attention = (c + 1) % 2 == 0 and c > 0 and use_att and p == 0
                 if voxel_resolution is None:
                     block = SharedMLP
                 else:
-                    block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
-                                              dropout=dropout,
-                                              with_se=with_se and not attention, with_se_relu=True,
-                                              normalize=normalize, eps=eps)
+                    block = functools.partial(
+                        PVConv,
+                        kernel_size=3,
+                        resolution=int(vr * voxel_resolution),
+                        attention=attention,
+                        dropout=dropout,
+                        with_se=with_se and not attention,
+                        with_se_relu=True,
+                        normalize=normalize,
+                        eps=eps,
+                    )
                 if c == 0:
                     sa_blocks.append(block(in_channels, out_channels))
-                elif k ==0:
-                    sa_blocks.append(block(in_channels+embed_dim, out_channels))
+                elif k == 0:
+                    sa_blocks.append(block(in_channels + embed_dim, out_channels))
                 in_channels = out_channels
                 k += 1
             extra_feature_channels = in_channels
@@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
         if num_centers is None:
             block = PointNetAModule
         else:
-            block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
-                                      num_neighbors=num_neighbors)
-        sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
-                               include_coordinates=True))
+            block = functools.partial(
+                PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
+            )
+        sa_blocks.append(
+            block(
+                in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
+                out_channels=out_channels,
+                include_coordinates=True,
+            )
+        )
         c += 1
         in_channels = extra_feature_channels = sa_blocks[-1].out_channels
         if len(sa_blocks) == 1:
@@ -127,10 +165,20 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
     return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers


-def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_points, embed_dim=64, use_att=False,
-                                dropout=0.1,
-                                with_se=False, normalize=True, eps=0,
-                                width_multiplier=1, voxel_resolution_multiplier=1):
+def create_pointnet2_fp_modules(
+    fp_blocks,
+    in_channels,
+    sa_in_channels,
+    sv_points,
+    embed_dim=64,
+    use_att=False,
+    dropout=0.1,
+    with_se=False,
+    normalize=True,
+    eps=0,
+    width_multiplier=1,
+    voxel_resolution_multiplier=1,
+):
     r, vr = width_multiplier, voxel_resolution_multiplier

     fp_layers = []
@@ -139,7 +187,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
         fp_blocks = []
         out_channels = tuple(int(r * oc) for oc in fp_configs)
         fp_blocks.append(
-            PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
+            PointNetFPModule(
+                in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
+            )
         )
         in_channels = out_channels[-1]

@@ -151,9 +201,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point
             if voxel_resolution is None:
                 block = SharedMLP
             else:
-                block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
-                                          dropout=dropout,
-                                          with_se=with_se and not attention,with_se_relu=True, normalize=normalize, eps=eps)
+                block = functools.partial(
+                    PVConv,
+                    kernel_size=3,
+                    resolution=int(vr * voxel_resolution),
+                    attention=attention,
+                    dropout=dropout,
+                    with_se=with_se and not attention,
+                    with_se_relu=True,
+                    normalize=normalize,
+                    eps=eps,
+                )
             fp_blocks.append(block(in_channels, out_channels))
             in_channels = out_channels
@@ -168,9 +226,17 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, sv_point


 class PVCNN2Base(nn.Module):
-
-    def __init__(self, num_classes, sv_points, embed_dim, use_att, dropout=0.1,
-                 extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
+    def __init__(
+        self,
+        num_classes,
+        sv_points,
+        embed_dim,
+        use_att,
+        dropout=0.1,
+        extra_feature_channels=3,
+        width_multiplier=1,
+        voxel_resolution_multiplier=1,
+    ):
         super().__init__()
         assert extra_feature_channels >= 0
         self.embed_dim = embed_dim
@@ -178,9 +244,14 @@ class PVCNN2Base(nn.Module):
         self.in_channels = extra_feature_channels + 3

         sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
-            sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
-            use_att=use_att, dropout=dropout,
-            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+            sa_blocks=self.sa_blocks,
+            extra_feature_channels=extra_feature_channels,
+            with_se=True,
+            embed_dim=embed_dim,
+            use_att=use_att,
+            dropout=dropout,
+            width_multiplier=width_multiplier,
+            voxel_resolution_multiplier=voxel_resolution_multiplier,
         )
         self.sa_layers = nn.ModuleList(sa_layers)

@@ -189,16 +260,26 @@ class PVCNN2Base(nn.Module):
         # only use extra features in the last fp module
         sa_in_channels[0] = extra_feature_channels
         fp_layers, channels_fp_features = create_pointnet2_fp_modules(
-            fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels,sv_points=sv_points,
-            with_se=True, embed_dim=embed_dim,
-            use_att=use_att, dropout=dropout,
-            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+            fp_blocks=self.fp_blocks,
+            in_channels=channels_sa_features,
+            sa_in_channels=sa_in_channels,
+            sv_points=sv_points,
+            with_se=True,
+            embed_dim=embed_dim,
+            use_att=use_att,
+            dropout=dropout,
+            width_multiplier=width_multiplier,
+            voxel_resolution_multiplier=voxel_resolution_multiplier,
         )
         self.fp_layers = nn.ModuleList(fp_layers)

-
-        layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, 0.5, num_classes],
-                                          classifier=True, dim=2, width_multiplier=width_multiplier)
+        layers, _ = create_mlp_components(
+            in_channels=channels_fp_features,
+            out_channels=[128, 0.5, num_classes],
+            classifier=True,
+            dim=2,
+            width_multiplier=width_multiplier,
+        )
         self.classifier = nn.Sequential(*layers)

         self.embedf = nn.Sequential(
@@ -223,31 +304,30 @@ class PVCNN2Base(nn.Module):
         return emb

     def forward(self, inputs, t):
-
-        temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
+        temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])

         # inputs : [B, in_channels + S, N]
         coords, features = inputs[:, :3, :].contiguous(), inputs
         coords_list, in_features_list = [], []
-        for i, sa_blocks  in enumerate(self.sa_layers):
+        for i, sa_blocks in enumerate(self.sa_layers):
             in_features_list.append(features)
             coords_list.append(coords)
             if i == 0:
-                features, coords, temb = sa_blocks ((features, coords, temb))
+                features, coords, temb = sa_blocks((features, coords, temb))
             else:
-                features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
+                features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
         in_features_list[0] = inputs[:, 3:, :].contiguous()
         if self.global_att is not None:
             features = self.global_att(features)
-        for fp_idx, fp_blocks  in enumerate(self.fp_layers):
+        for fp_idx, fp_blocks in enumerate(self.fp_layers):
             jump_coords = coords_list[-1 - fp_idx]
-            fump_feats = in_features_list[-1-fp_idx]
+            jump_feats = in_features_list[-1 - fp_idx]
             # if fp_idx == len(self.fp_layers) - 1:
             #     jump_coords = jump_coords[:,:,self.sv_points:]
             #     fump_feats = fump_feats[:,:,self.sv_points:]

-            features, coords, temb = fp_blocks((jump_coords, coords, torch.cat([features,temb],dim=1), fump_feats, temb))
+            features, coords, temb = fp_blocks(
+                (jump_coords, coords, torch.cat([features, temb], dim=1), jump_feats, temb)
+            )

         return self.classifier(features)
-
-
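Note on PVCNN2Base.forward above: every SA and FP level is conditioned on a per-step embedding temb produced by get_timestep_embedding and self.embedf. The underlying embedding is the usual sinusoidal one; the sketch below is illustrative only — the helper name and the odd-dimension padding are assumptions, not this file's exact code.

import torch

def sinusoidal_timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # t: int64 tensor [B] of diffusion steps -> float tensor [B, dim].
    half = dim // 2
    # Geometric progression of frequencies, as in Transformer position encodings.
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :].to(t.device)    # [B, half]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # [B, 2 * half]
    if dim % 2 == 1:  # zero-pad when dim is odd (assumed convention)
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# temb = sinusoidal_timestep_embedding(torch.arange(4), 64)  # -> shape [4, 64]

Expanding the resulting [B, C] embedding to [B, C, N], as forward does, simply repeats the same conditioning vector for every point in the cloud.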
diff --git a/model/pvcnn_generation.py b/model/pvcnn_generation.py
index 3926b9e..99720d7 100644
--- a/model/pvcnn_generation.py
+++ b/model/pvcnn_generation.py
@@ -1,13 +1,14 @@
 import functools
-import torch.nn as nn
-import torch
 import numpy as np
-from modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Attention, Swish
+import torch
+import torch.nn as nn
+
+from modules import Attention, PointNetAModule, PointNetFPModule, PointNetSAModule, PVConv, SharedMLP, Swish


 def _linear_gn_relu(in_channels, out_channels):
-    return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
+    return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())


 def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
@@ -43,8 +44,16 @@ def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, wi
     return layers, out_channels[-1] if classifier else int(r * out_channels[-1])


-def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
-                               width_multiplier=1, voxel_resolution_multiplier=1):
+def create_pointnet_components(
+    blocks,
+    in_channels,
+    embed_dim,
+    with_se=False,
+    normalize=True,
+    eps=0,
+    width_multiplier=1,
+    voxel_resolution_multiplier=1,
+):
     r, vr = width_multiplier, voxel_resolution_multiplier

     layers, concat_channels = [], 0
@@ -56,22 +65,38 @@ def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, no
         if voxel_resolution is None:
             block = SharedMLP
         else:
-            block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
-                                      with_se=with_se, normalize=normalize, eps=eps)
+            block = functools.partial(
+                PVConv,
+                kernel_size=3,
+                resolution=int(vr * voxel_resolution),
+                attention=attention,
+                with_se=with_se,
+                normalize=normalize,
+                eps=eps,
+            )
         if c == 0:
             layers.append(block(in_channels, out_channels))
         else:
-            layers.append(block(in_channels+embed_dim, out_channels))
+            layers.append(block(in_channels + embed_dim, out_channels))
         in_channels = out_channels
         concat_channels += out_channels
         c += 1
     return layers, in_channels, concat_channels


-def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=64, use_att=False,
-                                   dropout=0.1, with_se=False, normalize=True, eps=0,
-                                   width_multiplier=1, voxel_resolution_multiplier=1):
+def create_pointnet2_sa_components(
+    sa_blocks,
+    extra_feature_channels,
+    embed_dim=64,
+    use_att=False,
+    dropout=0.1,
+    with_se=False,
+    normalize=True,
+    eps=0,
+    width_multiplier=1,
+    voxel_resolution_multiplier=1,
+):
     r, vr = width_multiplier, voxel_resolution_multiplier
     in_channels = extra_feature_channels + 3

@@ -86,19 +111,26 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
         out_channels, num_blocks, voxel_resolution = conv_configs
         out_channels = int(r * out_channels)
         for p in range(num_blocks):
-            attention = (c+1) % 2 == 0 and use_att and p == 0
+            attention = (c + 1) % 2 == 0 and use_att and p == 0
             if voxel_resolution is None:
                 block = SharedMLP
             else:
-                block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
-                                          dropout=dropout,
-                                          with_se=with_se, with_se_relu=True,
-                                          normalize=normalize, eps=eps)
+                block = functools.partial(
+                    PVConv,
+                    kernel_size=3,
+                    resolution=int(vr * voxel_resolution),
+                    attention=attention,
+                    dropout=dropout,
+                    with_se=with_se,
+                    with_se_relu=True,
+                    normalize=normalize,
+                    eps=eps,
+                )
             if c == 0:
                 sa_blocks.append(block(in_channels, out_channels))
-            elif k ==0:
-                sa_blocks.append(block(in_channels+embed_dim, out_channels))
+            elif k == 0:
+                sa_blocks.append(block(in_channels + embed_dim, out_channels))
             in_channels = out_channels
             k += 1
         extra_feature_channels = in_channels
@@ -113,10 +145,16 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
         if num_centers is None:
             block = PointNetAModule
         else:
-            block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
-                                      num_neighbors=num_neighbors)
-        sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
-                               include_coordinates=True))
+            block = functools.partial(
+                PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors
+            )
+        sa_blocks.append(
+            block(
+                in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
+                out_channels=out_channels,
+                include_coordinates=True,
+            )
+        )
         c += 1
         in_channels = extra_feature_channels = sa_blocks[-1].out_channels
         if len(sa_blocks) == 1:
@@ -127,10 +165,19 @@ def create_pointnet2_sa_components(sa_blocks, extra_feature_channels, embed_dim=
     return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers


-def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
-                                dropout=0.1,
-                                with_se=False, normalize=True, eps=0,
-                                width_multiplier=1, voxel_resolution_multiplier=1):
+def create_pointnet2_fp_modules(
+    fp_blocks,
+    in_channels,
+    sa_in_channels,
+    embed_dim=64,
+    use_att=False,
+    dropout=0.1,
+    with_se=False,
+    normalize=True,
+    eps=0,
+    width_multiplier=1,
+    voxel_resolution_multiplier=1,
+):
     r, vr = width_multiplier, voxel_resolution_multiplier

     fp_layers = []
@@ -139,7 +186,9 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
         fp_blocks = []
         out_channels = tuple(int(r * oc) for oc in fp_configs)
         fp_blocks.append(
-            PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels)
+            PointNetFPModule(
+                in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels
+            )
         )
         in_channels = out_channels[-1]

@@ -147,14 +196,21 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di
         out_channels, num_blocks, voxel_resolution = conv_configs
         out_channels = int(r * out_channels)
         for p in range(num_blocks):
-            attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
+            attention = (c + 1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
             if voxel_resolution is None:
                 block = SharedMLP
             else:
-                block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
-                                          dropout=dropout,
-                                          with_se=with_se, with_se_relu=True,
-                                          normalize=normalize, eps=eps)
+                block = functools.partial(
+                    PVConv,
+                    kernel_size=3,
+                    resolution=int(vr * voxel_resolution),
+                    attention=attention,
+                    dropout=dropout,
+                    with_se=with_se,
+                    with_se_relu=True,
+                    normalize=normalize,
+                    eps=eps,
+                )
             fp_blocks.append(block(in_channels, out_channels))
             in_channels = out_channels
@@ -168,20 +224,31 @@ def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_di

     return fp_layers, in_channels


-
 class PVCNN2Base(nn.Module):
-
-    def __init__(self, num_classes, embed_dim, use_att, dropout=0.1,
-                 extra_feature_channels=3, width_multiplier=1, voxel_resolution_multiplier=1):
+    def __init__(
+        self,
+        num_classes,
+        embed_dim,
+        use_att,
+        dropout=0.1,
+        extra_feature_channels=3,
+        width_multiplier=1,
+        voxel_resolution_multiplier=1,
+    ):
         super().__init__()
         assert extra_feature_channels >= 0
         self.embed_dim = embed_dim
         self.in_channels = extra_feature_channels + 3

         sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
-            sa_blocks=self.sa_blocks, extra_feature_channels=extra_feature_channels, with_se=True, embed_dim=embed_dim,
-            use_att=use_att, dropout=dropout,
-            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+            sa_blocks=self.sa_blocks,
+            extra_feature_channels=extra_feature_channels,
+            with_se=True,
+            embed_dim=embed_dim,
+            use_att=use_att,
+            dropout=dropout,
+            width_multiplier=width_multiplier,
+            voxel_resolution_multiplier=voxel_resolution_multiplier,
         )
         self.sa_layers = nn.ModuleList(sa_layers)

@@ -190,15 +257,25 @@ class PVCNN2Base(nn.Module):
         # only use extra features in the last fp module
         sa_in_channels[0] = extra_feature_channels
         fp_layers, channels_fp_features = create_pointnet2_fp_modules(
-            fp_blocks=self.fp_blocks, in_channels=channels_sa_features, sa_in_channels=sa_in_channels, with_se=True, embed_dim=embed_dim,
-            use_att=use_att, dropout=dropout,
-            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+            fp_blocks=self.fp_blocks,
+            in_channels=channels_sa_features,
+            sa_in_channels=sa_in_channels,
+            with_se=True,
+            embed_dim=embed_dim,
+            use_att=use_att,
+            dropout=dropout,
+            width_multiplier=width_multiplier,
+            voxel_resolution_multiplier=voxel_resolution_multiplier,
         )
         self.fp_layers = nn.ModuleList(fp_layers)

-
-        layers, _ = create_mlp_components(in_channels=channels_fp_features, out_channels=[128, dropout, num_classes],  # was 0.5
-                                          classifier=True, dim=2, width_multiplier=width_multiplier)
+        layers, _ = create_mlp_components(
+            in_channels=channels_fp_features,
+            out_channels=[128, dropout, num_classes],  # was 0.5
+            classifier=True,
+            dim=2,
+            width_multiplier=width_multiplier,
+        )
         self.classifier = nn.Sequential(*layers)

         self.embedf = nn.Sequential(
@@ -223,25 +300,30 @@ class PVCNN2Base(nn.Module):
         return emb

     def forward(self, inputs, t):
-
-        temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:,:,None].expand(-1,-1,inputs.shape[-1])
+        temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])

         # inputs : [B, in_channels + S, N]
         coords, features = inputs[:, :3, :].contiguous(), inputs
         coords_list, in_features_list = [], []
-        for i, sa_blocks  in enumerate(self.sa_layers):
+        for i, sa_blocks in enumerate(self.sa_layers):
             in_features_list.append(features)
             coords_list.append(coords)
             if i == 0:
-                features, coords, temb = sa_blocks ((features, coords, temb))
+                features, coords, temb = sa_blocks((features, coords, temb))
             else:
-                features, coords, temb = sa_blocks ((torch.cat([features,temb],dim=1), coords, temb))
+                features, coords, temb = sa_blocks((torch.cat([features, temb], dim=1), coords, temb))
         in_features_list[0] = inputs[:, 3:, :].contiguous()
         if self.global_att is not None:
             features = self.global_att(features)
-        for fp_idx, fp_blocks  in enumerate(self.fp_layers):
-            features, coords, temb = fp_blocks((coords_list[-1-fp_idx], coords, torch.cat([features,temb],dim=1), in_features_list[-1-fp_idx], temb))
+        for fp_idx, fp_blocks in enumerate(self.fp_layers):
+            features, coords, temb = fp_blocks(
+                (
+                    coords_list[-1 - fp_idx],
+                    coords,
+                    torch.cat([features, temb], dim=1),
+                    in_features_list[-1 - fp_idx],
+                    temb,
+                )
+            )

         return self.classifier(features)
-
-
diff --git a/modules/__init__.py b/modules/__init__.py
index 89290fc..e69de29 100644
--- a/modules/__init__.py
+++ b/modules/__init__.py
@@ -1,8 +0,0 @@
-from modules.ball_query import BallQuery
-from modules.frustum import FrustumPointNetLoss
-from modules.loss import KLLoss
-from modules.pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
-from modules.pvconv import PVConv, Attention, Swish, PVConvReLU
-from modules.se import SE3d
-from modules.shared_mlp import SharedMLP
-from modules.voxelization import Voxelization
diff --git a/modules/ball_query.py b/modules/ball_query.py
index 20251d0..12b74ff 100644
--- a/modules/ball_query.py
+++ b/modules/ball_query.py
@@ -3,7 +3,7 @@ import torch.nn as nn

 import modules.functional as F

-__all__ = ['BallQuery']
+__all__ = ["BallQuery"]


 class BallQuery(nn.Module):
@@ -21,7 +21,7 @@ class BallQuery(nn.Module):
         neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)

         if points_features is None:
-            assert self.include_coordinates, 'No Features For Grouping'
+            assert self.include_coordinates, "No Features For Grouping"
             neighbor_features = neighbor_coordinates
         else:
             neighbor_features = F.grouping(points_features, neighbor_indices)
@@ -30,5 +30,6 @@ class BallQuery(nn.Module):
         return neighbor_features, F.grouping(temb, neighbor_indices)

     def extra_repr(self):
-        return 'radius={}, num_neighbors={}{}'.format(
-            self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
+        return "radius={}, num_neighbors={}{}".format(
+            self.radius, self.num_neighbors, ", include coordinates" if self.include_coordinates else ""
+        )
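Context for BallQuery above: F.ball_query is a compiled CUDA kernel that returns, for each center, up to num_neighbors point indices within radius. A brute-force PyTorch reference of that lookup follows; it is a sketch under assumptions (first-k selection, and padding under-full balls with their first hit), so the real kernel's ordering and empty-ball behaviour may differ.

import torch

def ball_query_reference(centers: torch.Tensor, points: torch.Tensor, radius: float, k: int) -> torch.Tensor:
    # centers: [B, 3, M], points: [B, 3, N] -> neighbor indices [B, M, k].
    B, _, M = centers.shape
    N = points.shape[2]
    dist2 = torch.cdist(centers.transpose(1, 2), points.transpose(1, 2)) ** 2  # [B, M, N]
    idx = torch.arange(N, device=points.device).expand(B, M, N).clone()
    idx[dist2 >= radius * radius] = N          # move out-of-range points past the end
    idx = idx.sort(dim=-1).values[:, :, :k]    # keep the first k in-range indices
    first = idx[:, :, :1].clamp(max=N - 1)     # pad under-full balls with their first hit
    return torch.where(idx == N, first, idx)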
diff --git a/modules/frustum.py b/modules/frustum.py
index e8d95d2..8494436 100644
--- a/modules/frustum.py
+++ b/modules/frustum.py
@@ -5,12 +5,20 @@ import torch.nn.functional as F

 import modules.functional as PF

-__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
+__all__ = ["FrustumPointNetLoss", "get_box_corners_3d"]


 class FrustumPointNetLoss(nn.Module):
-    def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
-                 corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
+    def __init__(
+        self,
+        num_heading_angle_bins,
+        num_size_templates,
+        size_templates,
+        box_loss_weight=1.0,
+        corners_loss_weight=10.0,
+        heading_residual_loss_weight=20.0,
+        size_residual_loss_weight=20.0,
+    ):
         super().__init__()
         self.box_loss_weight = box_loss_weight
         self.corners_loss_weight = corners_loss_weight
@@ -19,28 +27,28 @@ class FrustumPointNetLoss(nn.Module):
         self.num_heading_angle_bins = num_heading_angle_bins
         self.num_size_templates = num_size_templates

-        self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
+        self.register_buffer("size_templates", size_templates.view(self.num_size_templates, 3))
         self.register_buffer(
-            'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
+            "heading_angle_bin_centers", torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
         )

     def forward(self, inputs, targets):
-        mask_logits = inputs['mask_logits']  # (B, 2, N)
-        center_reg = inputs['center_reg']  # (B, 3)
-        center = inputs['center']  # (B, 3)
-        heading_scores = inputs['heading_scores']  # (B, NH)
-        heading_residuals_normalized = inputs['heading_residuals_normalized']  # (B, NH)
-        heading_residuals = inputs['heading_residuals']  # (B, NH)
-        size_scores = inputs['size_scores']  # (B, NS)
-        size_residuals_normalized = inputs['size_residuals_normalized']  # (B, NS, 3)
-        size_residuals = inputs['size_residuals']  # (B, NS, 3)
+        mask_logits = inputs["mask_logits"]  # (B, 2, N)
+        center_reg = inputs["center_reg"]  # (B, 3)
+        center = inputs["center"]  # (B, 3)
+        heading_scores = inputs["heading_scores"]  # (B, NH)
+        heading_residuals_normalized = inputs["heading_residuals_normalized"]  # (B, NH)
+        heading_residuals = inputs["heading_residuals"]  # (B, NH)
+        size_scores = inputs["size_scores"]  # (B, NS)
+        size_residuals_normalized = inputs["size_residuals_normalized"]  # (B, NS, 3)
+        size_residuals = inputs["size_residuals"]  # (B, NS, 3)

-        mask_logits_target = targets['mask_logits']  # (B, N)
-        center_target = targets['center']  # (B, 3)
-        heading_bin_id_target = targets['heading_bin_id']  # (B, )
-        heading_residual_target = targets['heading_residual']  # (B, )
-        size_template_id_target = targets['size_template_id']  # (B, )
-        size_residual_target = targets['size_residual']  # (B, 3)
+        mask_logits_target = targets["mask_logits"]  # (B, N)
+        center_target = targets["center"]  # (B, 3)
+        heading_bin_id_target = targets["heading_bin_id"]  # (B, )
+        heading_residual_target = targets["heading_residual"]  # (B, )
+        size_template_id_target = targets["size_template_id"]  # (B, )
+        size_residual_target = targets["size_residual"]  # (B, 3)

         batch_size = center.size(0)
         batch_id = torch.arange(batch_size, device=center.device)
@@ -65,25 +73,32 @@ class FrustumPointNetLoss(nn.Module):
         )

         # Bounding box losses
-        heading = (heading_residuals[batch_id, heading_bin_id_target]
-                   + self.heading_angle_bin_centers[heading_bin_id_target])  # (B, )
+        heading = (
+            heading_residuals[batch_id, heading_bin_id_target] + self.heading_angle_bin_centers[heading_bin_id_target]
+        )  # (B, )
         # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
-        size = (size_residuals[batch_id, size_template_id_target]
-                + self.size_templates[size_template_id_target])  # (B, 3)
+        size = (
+            size_residuals[batch_id, size_template_id_target] + self.size_templates[size_template_id_target]
+        )  # (B, 3)
         corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False)  # (B, 3, 8)
         heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target  # (B, )
         size_target = self.size_templates[size_template_id_target] + size_residual_target  # (B, 3)
-        corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
-                                                                 sizes=size_target, with_flip=True)  # (B, 3, 8)
-        corners_loss = PF.huber_loss(torch.min(
-            torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
-        ), delta=1.0)
+        corners_target, corners_target_flip = get_box_corners_3d(
+            centers=center_target, headings=heading_target, sizes=size_target, with_flip=True
+        )  # (B, 3, 8)
+        corners_loss = PF.huber_loss(
+            torch.min(torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)),
+            delta=1.0,
+        )
         # Summing up
         loss = mask_loss + self.box_loss_weight * (
-            center_loss + center_reg_loss + heading_loss + size_loss
-            + self.heading_residual_loss_weight * heading_residual_normalized_loss
-            + self.size_residual_loss_weight * size_residual_normalized_loss
-            + self.corners_loss_weight * corners_loss
+            center_loss
+            + center_reg_loss
+            + heading_loss
+            + size_loss
+            + self.heading_residual_loss_weight * heading_residual_normalized_loss
+            + self.size_residual_loss_weight * size_residual_normalized_loss
+            + self.corners_loss_weight * corners_loss
         )

         return loss
@@ -105,9 +120,9 @@ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
     l = sizes[:, 0]  # (N,)
     w = sizes[:, 1]  # (N,)
     h = sizes[:, 2]  # (N,)
-    x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1)  # (N, 8)
-    y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1)  # (N, 8)
-    z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1)  # (N, 8)
+    x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1)  # (N, 8)
+    y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1)  # (N, 8)
+    z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1)  # (N, 8)

     c = torch.cos(headings)  # (N,)
     s = torch.sin(headings)  # (N,)
diff --git a/modules/functional/__init__.py b/modules/functional/__init__.py
index ce707cc..e69de29 100644
--- a/modules/functional/__init__.py
+++ b/modules/functional/__init__.py
@@ -1,7 +0,0 @@
-from modules.functional.ball_query import ball_query
-from modules.functional.devoxelization import trilinear_devoxelize
-from modules.functional.grouping import grouping
-from modules.functional.interpolatation import nearest_neighbor_interpolate
-from modules.functional.loss import kl_loss, huber_loss
-from modules.functional.sampling import gather, furthest_point_sample, logits_mask
-from modules.functional.voxelization import avg_voxelize
diff --git a/modules/functional/backend.py b/modules/functional/backend.py
index d6232ca..4745bf6 100644
--- a/modules/functional/backend.py
+++ b/modules/functional/backend.py
@@ -3,24 +3,28 @@ import os
 from torch.utils.cpp_extension import load

 _src_path = os.path.dirname(os.path.abspath(__file__))
-_backend = load(name='_pvcnn_backend',
-                extra_cflags=['-O3', '-std=c++17'],
-                extra_cuda_cflags=['--compiler-bindir=/softs/gcc/11.2.0/bin/gcc'],
-                sources=[os.path.join(_src_path,'src', f) for f in [
-                    'ball_query/ball_query.cpp',
-                    'ball_query/ball_query.cu',
-                    'grouping/grouping.cpp',
-                    'grouping/grouping.cu',
-                    'interpolate/neighbor_interpolate.cpp',
-                    'interpolate/neighbor_interpolate.cu',
-                    'interpolate/trilinear_devox.cpp',
-                    'interpolate/trilinear_devox.cu',
-                    'sampling/sampling.cpp',
-                    'sampling/sampling.cu',
-                    'voxelization/vox.cpp',
-                    'voxelization/vox.cu',
-                    'bindings.cpp',
-                ]]
-                )
+_backend = load(
+    name="_pvcnn_backend",
+    extra_cflags=["-O3", "-std=c++17"],
+    extra_cuda_cflags=["--compiler-bindir=/softs/gcc/11.2.0/bin/gcc"],
+    sources=[
+        os.path.join(_src_path, "src", f)
+        for f in [
+            "ball_query/ball_query.cpp",
+            "ball_query/ball_query.cu",
+            "grouping/grouping.cpp",
+            "grouping/grouping.cu",
+            "interpolate/neighbor_interpolate.cpp",
+            "interpolate/neighbor_interpolate.cu",
+            "interpolate/trilinear_devox.cpp",
+            "interpolate/trilinear_devox.cu",
+            "sampling/sampling.cpp",
+            "sampling/sampling.cu",
+            "voxelization/vox.cpp",
+            "voxelization/vox.cu",
+            "bindings.cpp",
+        ]
+    ],
+)

-__all__ = ['_backend']
+__all__ = ["_backend"]
diff --git a/modules/functional/ball_query.py b/modules/functional/ball_query.py
index a99df0d..ebdb8fb 100644
--- a/modules/functional/ball_query.py
+++ b/modules/functional/ball_query.py
@@ -1,19 +1,17 @@
-from torch.autograd import Function
-
 from modules.functional.backend import _backend

-__all__ = ['ball_query']
+__all__ = ["ball_query"]


 def ball_query(centers_coords, points_coords, radius, num_neighbors):
-    """
-    :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
-    :param points_coords: coordinates of points, FloatTensor[B, 3, N]
-    :param radius: float, radius of ball query
-    :param num_neighbors: int, maximum number of neighbors
-    :return:
-        neighbor_indices: indices of neighbors, IntTensor[B, M, U]
-    """
-    centers_coords = centers_coords.contiguous()
-    points_coords = points_coords.contiguous()
-    return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
+    """
+    :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
+    :param points_coords: coordinates of points, FloatTensor[B, 3, N]
+    :param radius: float, radius of ball query
+    :param num_neighbors: int, maximum number of neighbors
+    :return:
+        neighbor_indices: indices of neighbors, IntTensor[B, M, U]
+    """
+    centers_coords = centers_coords.contiguous()
+    points_coords = points_coords.contiguous()
+    return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
diff --git a/modules/functional/devoxelization.py b/modules/functional/devoxelization.py
index b037f48..175408d 100644
--- a/modules/functional/devoxelization.py
+++ b/modules/functional/devoxelization.py
@@ -2,7 +2,7 @@ from torch.autograd import Function

 from modules.functional.backend import _backend

-__all__ = ['trilinear_devoxelize']
+__all__ = ["trilinear_devoxelize"]


 class TrilinearDevoxelization(Function):
@@ -29,7 +29,7 @@ class TrilinearDevoxelization(Function):
     @staticmethod
     def backward(ctx, grad_output):
         """
-        :param ctx: 
+        :param ctx:
         :param grad_output: gradient of outputs, FloatTensor[B, C, N]
         :return:
             gradient of inputs, FloatTensor[B, C, R, R, R]
diff --git a/modules/functional/grouping.py b/modules/functional/grouping.py
index 72855ea..dfd91e4 100644
--- a/modules/functional/grouping.py
+++ b/modules/functional/grouping.py
@@ -2,7 +2,7 @@ from torch.autograd import Function

 from modules.functional.backend import _backend

-__all__ = ['grouping']
+__all__ = ["grouping"]


 class Grouping(Function):
@@ -23,7 +23,7 @@ class Grouping(Function):

     @staticmethod
     def backward(ctx, grad_output):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
         return grad_features, None
diff --git a/modules/functional/interpolatation.py b/modules/functional/interpolatation.py
index 5a42425..28267fd 100644
--- a/modules/functional/interpolatation.py
+++ b/modules/functional/interpolatation.py
@@ -2,7 +2,7 @@ from torch.autograd import Function

 from modules.functional.backend import _backend

-__all__ = ['nearest_neighbor_interpolate']
+__all__ = ["nearest_neighbor_interpolate"]


 class NeighborInterpolation(Function):
diff --git a/modules/functional/loss.py b/modules/functional/loss.py
index 41112b3..937c7ee 100644
--- a/modules/functional/loss.py
+++ b/modules/functional/loss.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F

-__all__ = ['kl_loss', 'huber_loss']
+__all__ = ["kl_loss", "huber_loss"]


 def kl_loss(x, y):
@@ -13,5 +13,5 @@ def kl_loss(x, y):
 def huber_loss(error, delta):
     abs_error = torch.abs(error)
     quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
-    losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
+    losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
     return torch.mean(losses)
diff --git a/modules/functional/sampling.py b/modules/functional/sampling.py
index 160450b..5c57336 100644
--- a/modules/functional/sampling.py
+++ b/modules/functional/sampling.py
@@ -4,7 +4,7 @@ from torch.autograd import Function

 from modules.functional.backend import _backend

-__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
+__all__ = ["gather", "furthest_point_sample", "logits_mask"]


 class Gather(Function):
@@ -26,7 +26,7 @@ class Gather(Function):

     @staticmethod
     def backward(ctx, grad_output):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
         return grad_features, None

@@ -60,11 +60,12 @@ def logits_mask(coords, logits, num_points_per_object):
         mask: mask to select points, BoolTensor[B, N]
     """
     batch_size, _, num_points = coords.shape
-    mask = torch.lt(logits[:, 0, :], logits[:, 1, :])   # [B, N]
+    mask = torch.lt(logits[:, 0, :], logits[:, 1, :])  # [B, N]
     num_candidates = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
     masked_coords = coords * mask.view(batch_size, 1, num_points)  # [B, C, N]
-    masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
-                                                                      torch.ones_like(num_candidates)).float()  # [B, C]
+    masked_coords_mean = (
+        torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, torch.ones_like(num_candidates)).float()
+    )  # [B, C]
     selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
     for i in range(batch_size):
         current_mask = mask[i]  # [N]
@@ -74,10 +75,14 @@ def logits_mask(coords, logits, num_points_per_object):
             choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
             selected_indices[i] = current_candidates[choices]
         elif current_num_candidates > 0:
-            choices = np.concatenate([
-                np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
-                np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
-            ])
+            choices = np.concatenate(
+                [
+                    np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
+                    np.random.choice(
+                        current_num_candidates, num_points_per_object % current_num_candidates, replace=False
+                    ),
+                ]
+            )
             np.random.shuffle(choices)
             selected_indices[i] = current_candidates[choices]
     selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
diff --git a/modules/functional/voxelization.py b/modules/functional/voxelization.py
index 2452c68..6c048d5 100644
--- a/modules/functional/voxelization.py
+++ b/modules/functional/voxelization.py
@@ -2,7 +2,7 @@ from torch.autograd import Function

 from modules.functional.backend import _backend

-__all__ = ['avg_voxelize']
+__all__ = ["avg_voxelize"]


 class AvgVoxelization(Function):
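Aside on modules/functional/loss.py above: huber_loss is the standard Huber penalty — quadratic inside delta, linear outside — so for delta = 1.0 it should agree with PyTorch's built-in Huber loss. A quick numerical sanity check (assumes torch >= 1.9 for F.huber_loss):

import torch
import torch.nn.functional as F

def huber_loss(error, delta):
    # Same arithmetic as the module above: min(|e|, delta) splits the two regimes.
    abs_error = torch.abs(error)
    quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
    return torch.mean(0.5 * quadratic**2 + delta * (abs_error - quadratic))

err = torch.randn(1024)
reference = F.huber_loss(err, torch.zeros_like(err), delta=1.0)
assert torch.allclose(huber_loss(err, 1.0), reference)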
diff --git a/modules/loss.py b/modules/loss.py
index 173052d..b6b5e03 100644
--- a/modules/loss.py
+++ b/modules/loss.py
@@ -2,7 +2,7 @@ import torch.nn as nn

 import modules.functional as F

-__all__ = ['KLLoss']
+__all__ = ["KLLoss"]


 class KLLoss(nn.Module):
diff --git a/modules/pointnet.py b/modules/pointnet.py
index 7925acf..c96d4a7 100644
--- a/modules/pointnet.py
+++ b/modules/pointnet.py
@@ -5,7 +5,7 @@ import modules.functional as F
 from modules.ball_query import BallQuery
 from modules.shared_mlp import SharedMLP

-__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
+__all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]


 class PointNetAModule(nn.Module):
@@ -20,8 +20,9 @@ class PointNetAModule(nn.Module):
             total_out_channels = 0
             for _out_channels in out_channels:
                 mlps.append(
-                    SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
-                              out_channels=_out_channels, dim=1)
+                    SharedMLP(
+                        in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=1
+                    )
                 )
                 total_out_channels += _out_channels[-1]

@@ -43,7 +44,7 @@ class PointNetAModule(nn.Module):
             return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords

     def extra_repr(self):
-        return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
+        return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"


 class PointNetSAModule(nn.Module):
@@ -67,8 +68,9 @@ class PointNetSAModule(nn.Module):
                     BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
                 )
                 mlps.append(
-                    SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
-                              out_channels=_out_channels, dim=2)
+                    SharedMLP(
+                        in_channels=in_channels + (3 if include_coordinates else 0), out_channels=_out_channels, dim=2
+                    )
                 )
                 total_out_channels += _out_channels[-1]

@@ -90,7 +92,7 @@ class PointNetSAModule(nn.Module):
         return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb

     def extra_repr(self):
-        return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
+        return f"num_centers={self.num_centers}, out_channels={self.out_channels}"


 class PointNetFPModule(nn.Module):
@@ -107,7 +109,5 @@ class PointNetFPModule(nn.Module):
         interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
         interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
         if points_features is not None:
-            interpolated_features = torch.cat(
-                [interpolated_features, points_features], dim=1
-            )
+            interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
         return self.mlp(interpolated_features), points_coords, interpolated_temb
diff --git a/modules/pvconv.py b/modules/pvconv.py
index bcacfb0..bf1737e 100644
--- a/modules/pvconv.py
+++ b/modules/pvconv.py
@@ -1,16 +1,17 @@
-import torch.nn as nn
 import torch
-import modules.functional as F
-from modules.voxelization import Voxelization
-from modules.shared_mlp import SharedMLP
-from modules.se import SE3d
+import torch.nn as nn

-__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
+import modules.functional as F
+from modules.se import SE3d
+from modules.shared_mlp import SharedMLP
+from modules.voxelization import Voxelization
+
+__all__ = ["PVConv", "Attention", "Swish", "PVConvReLU"]


 class Swish(nn.Module):
-    def forward(self,x):
-        return x * torch.sigmoid(x)
+    def forward(self, x):
+        return x * torch.sigmoid(x)


 class Attention(nn.Module):
@@ -35,23 +36,19 @@ class Attention(nn.Module):

         self.sm = nn.Softmax(-1)

-
     def forward(self, x):
         B, C = x.shape[:2]
         h = x
+        q = self.q(h).reshape(B, C, -1)
+        k = self.k(h).reshape(B, C, -1)
+        v = self.v(h).reshape(B, C, -1)

-
-
-        q = self.q(h).reshape(B,C,-1)
-        k = self.k(h).reshape(B,C,-1)
-        v = self.v(h).reshape(B,C,-1)
-
-        qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
+        qk = torch.matmul(q.permute(0, 2, 1), k)  # * (int(C) ** (-0.5))

         w = self.sm(qk)

-        h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:])
+        h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, *x.shape[2:])

         h = self.out(h)

@@ -61,9 +58,21 @@ class Attention(nn.Module):

         return x

+
 class PVConv(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
-                 dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        resolution,
+        attention=False,
+        dropout=0.1,
+        with_se=False,
+        with_se_relu=False,
+        normalize=True,
+        eps=0,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -74,13 +83,13 @@ class PVConv(nn.Module):
         voxel_layers = [
             nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
             nn.GroupNorm(num_groups=8, num_channels=out_channels),
-            Swish()
+            Swish(),
         ]
         voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
         voxel_layers += [
             nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
             nn.GroupNorm(num_groups=8, num_channels=out_channels),
-            Attention(out_channels, 8) if attention else Swish()
+            Attention(out_channels, 8) if attention else Swish(),
         ]
         if with_se:
             voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
@@ -96,10 +105,21 @@ class PVConv(nn.Module):

         return fused_features, coords, temb

-
 class PVConvReLU(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
-                 dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        resolution,
+        attention=False,
+        leak=0.2,
+        dropout=0.1,
+        with_se=False,
+        with_se_relu=False,
+        normalize=True,
+        eps=0,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -110,13 +130,13 @@ class PVConvReLU(nn.Module):
         voxel_layers = [
             nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
             nn.BatchNorm3d(out_channels),
-            nn.LeakyReLU(leak, True)
+            nn.LeakyReLU(leak, True),
         ]
         voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
         voxel_layers += [
             nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
             nn.BatchNorm3d(out_channels),
-            Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
+            Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True),
         ]
         if with_se:
             voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
diff --git a/modules/se.py b/modules/se.py
index c34eef7..11f2ded 100644
--- a/modules/se.py
+++ b/modules/se.py
@@ -1,18 +1,22 @@
-import torch.nn as nn
 import torch
-__all__ = ['SE3d']
+import torch.nn as nn
+
+__all__ = ["SE3d"]
+

 class Swish(nn.Module):
-    def forward(self,x):
-        return x * torch.sigmoid(x)
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
 class SE3d(nn.Module):
     def __init__(self, channel, reduction=8, use_relu=False):
         super().__init__()
         self.fc = nn.Sequential(
             nn.Linear(channel, channel // reduction, bias=False),
-            nn.ReLU(True) if use_relu else Swish() ,
+            nn.ReLU(True) if use_relu else Swish(),
             nn.Linear(channel // reduction, channel, bias=False),
-            nn.Sigmoid()
+            nn.Sigmoid(),
         )

     def forward(self, inputs):
diff --git a/modules/shared_mlp.py b/modules/shared_mlp.py
index 1fcc35e..8ef2780 100644
--- a/modules/shared_mlp.py
+++ b/modules/shared_mlp.py
@@ -1,12 +1,13 @@
-import torch.nn as nn
 import torch
+import torch.nn as nn

-__all__ = ['SharedMLP']
+__all__ = ["SharedMLP"]


 class Swish(nn.Module):
-    def forward(self,x):
-        return x * torch.sigmoid(x)
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+

 class SharedMLP(nn.Module):
     def __init__(self, in_channels, out_channels, dim=1):
@@ -23,11 +24,13 @@ class SharedMLP(nn.Module):
             out_channels = [out_channels]
         layers = []
         for oc in out_channels:
-            layers.extend([
-                conv(in_channels, oc, 1),
-                bn(8, oc),
-                Swish(),
-            ])
+            layers.extend(
+                [
+                    conv(in_channels, oc, 1),
+                    bn(8, oc),
+                    Swish(),
+                ]
+            )
             in_channels = oc
         self.layers = nn.Sequential(*layers)
diff --git a/modules/voxelization.py b/modules/voxelization.py
index 7efc614..830341e 100644
--- a/modules/voxelization.py
+++ b/modules/voxelization.py
@@ -3,7 +3,7 @@ import torch.nn as nn

 import modules.functional as F

-__all__ = ['Voxelization']
+__all__ = ["Voxelization"]


 class Voxelization(nn.Module):
@@ -17,7 +17,10 @@ class Voxelization(nn.Module):
         coords = coords.detach()
         norm_coords = coords - coords.mean(2, keepdim=True)
         if self.normalize:
-            norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
+            norm_coords = (
+                norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps)
+                + 0.5
+            )
         else:
             norm_coords = (norm_coords + 1) / 2.0
         norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
@@ -25,4 +28,4 @@ class Voxelization(nn.Module):
         return F.avg_voxelize(features, vox_coords, self.r), norm_coords

     def extra_repr(self):
-        return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
+        return "resolution={}{}".format(self.r, ", normalized eps = {}".format(self.eps) if self.normalize else "")
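On Voxelization.forward above: coordinates are centred per cloud, squeezed into [0, 1] by twice the largest point norm (plus eps), then scaled into the r^3 grid and clamped before avg_voxelize scatters features onto it. The same arithmetic in isolation (illustrative sketch; the module additionally rounds to integer voxel indices before calling the kernel):

import torch

def voxel_coords(coords: torch.Tensor, r: int, eps: float = 0.0) -> torch.Tensor:
    # coords: [B, 3, N] -> fractional voxel coordinates [B, 3, N] in [0, r - 1].
    norm = coords - coords.mean(dim=2, keepdim=True)               # centre each cloud
    scale = norm.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values
    norm = norm / (scale * 2.0 + eps) + 0.5                        # fit into [0, 1]
    return torch.clamp(norm * r, 0, r - 1)

v = voxel_coords(torch.randn(2, 3, 128), r=32)
assert v.min() >= 0 and v.max() <= 31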
""" - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2)) + def discretized_gaussian_log_likelihood(x, *, means, log_scales): # Assumes data is integers [0, 1] @@ -31,21 +32,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 0.5) cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) + min_in = inv_stdv * (centered_x - 0.5) cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + x < 0.001, + log_cdf_plus, + torch.where( + x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12)) + ), + ) assert log_probs.shape == x.shape return log_probs - class GaussianDiffusion: def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): self.loss_type = loss_type @@ -54,15 +57,15 @@ class GaussianDiffusion: assert isinstance(betas, np.ndarray) self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.sv_points = sv_points # initialize twice the actual length so we can keep running for eval # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - alphas = 1. - betas + alphas = 1.0 - betas alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float() self.betas = torch.from_numpy(betas).float() self.alphas_cumprod = alphas_cumprod.float() @@ -70,21 +73,23 @@ class GaussianDiffusion: # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float() betas = torch.from_numpy(betas).float() alphas = torch.from_numpy(alphas).float() # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) self.posterior_variance = posterior_variance # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)) + ) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) @staticmethod def _extract(a, t, x_shape): @@ -92,17 +97,15 @@ class GaussianDiffusion: Extract some coefficients at specified timesteps, then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. """ - bs, = t.shape + (bs,) = t.shape assert x_shape[0] == bs out = torch.gather(a, 0, t) assert out.shape == torch.Size([bs]) return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - def q_mean_variance(self, x_start, t): mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape) log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) return mean, variance, log_variance @@ -114,54 +117,59 @@ class GaussianDiffusion: noise = torch.randn(x_start.shape, device=x_start.device) assert noise.shape == x_start.shape return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise ) - def q_posterior_mean_variance(self, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t ) posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) + posterior_log_variance_clipped = self._extract( + self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + 
model_output = denoise_fn(data, t)[:, :, self.sv_points :] - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: + if self.model_var_type in ["fixedsmall", "fixedlarge"]: # below: only log_variance is used in the KL computations model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + "fixedlarge": ( + self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device), + ), + "fixedsmall": ( + self.posterior_variance.to(data.device), + self.posterior_log_variance_clipped.to(data.device), + ), }[self.model_var_type] model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) else: raise NotImplementedError(self.model_var_type) - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + if self.model_mean_type == "eps": + x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output) - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t) else: raise NotImplementedError(self.loss_type) - assert model_mean.shape == x_recon.shape assert model_variance.shape == model_log_variance.shape if return_pred_xstart: @@ -172,30 +180,31 @@ class GaussianDiffusion: def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t + - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps ) - ''' samples ''' + """ samples """ def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): """ Sample from the model """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) # no noise when t == 0 nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1) return (sample, pred_xstart) if return_pred_xstart else sample - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): + def p_sample_loop( + self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, 
keep_running=False + ): """ Generate samples keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps @@ -206,14 +215,21 @@ class GaussianDiffusion: img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ) - assert img_t[:,:,self.sv_points:].shape == shape + assert img_t[:, :, self.sv_points :].shape == shape return img_t - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): + def p_sample_loop_trajectory( + self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False + ): """ Generate samples, returning intermediate images Useful for visualizing how denoised images evolve over time @@ -223,31 +239,38 @@ class GaussianDiffusion: """ assert isinstance(shape, (tuple, list)) - total_steps = self.num_timesteps if not keep_running else len(self.betas) + total_steps = self.num_timesteps if not keep_running else len(self.betas) img_t = noise_fn(size=shape, dtype=torch.float, device=device) imgs = [img_t] - for t in reversed(range(0,total_steps)): - + for t in reversed(range(0, total_steps)): t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ) + if t % freq == 0 or t == total_steps - 1: imgs.append(img_t) assert imgs[-1].shape == shape return imgs - '''losses''' + """losses""" def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t + ) model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) 
+ kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0) return (kl, pred_xstart) if return_pred_xstart else kl @@ -259,66 +282,87 @@ class GaussianDiffusion: assert t.shape == torch.Size([B]) if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + noise = torch.randn( + data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device + ) - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise) - if self.loss_type == 'mse': + if self.loss_type == "mse": # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[ + :, :, self.sv_points : + ] - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': + losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == "kl": losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) + denoise_fn=denoise_fn, + data_start=data_start, + data_t=data_t, + t=t, + clip_denoised=False, + return_pred_xstart=False, + ) else: raise NotImplementedError(self.loss_type) assert losses.shape == torch.Size([B]) return losses - '''debug''' + """debug""" def _prior_bpd(self, x_start): - with torch.no_grad(): B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + kl_prior = normal_kl( + mean1=qt_mean, + logvar1=qt_log_variance, + mean2=torch.tensor([0.0]).to(qt_mean), + logvar2=torch.tensor([0.0]).to(qt_log_variance), + ) assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
+ return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0) def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - with torch.no_grad(): B, T = x_start.shape[0], self.num_timesteps - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) for t in reversed(range(T)): - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + data_t = torch.cat( + [x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)], + dim=-1, + ) new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) + denoise_fn, + data_start=x_start, + data_t=data_t, + t=t_b, + clip_denoised=clip_denoised, + return_pred_xstart=True, + ) # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean( + dim=list(range(1, len(pred_xstart.shape))) + ) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float() vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :]) total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + assert vals_bt_.shape == mse_bt_.shape == torch.Size( + [B, T] + ) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() @@ -336,39 +380,53 @@ class PVCNN2(PVCNN2Base): ((128, 128, 64), (64, 2, 32)), ] - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): + def __init__( + self, + num_classes, + sv_points, + embed_dim, + use_att, + dropout, + extra_feature_channels=3, + width_multiplier=1, + voxel_resolution_multiplier=1, + ): super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + num_classes=num_classes, + sv_points=sv_points, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, ) class Model(nn.Module): - 
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str): super(Model, self).__init__() self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) + self.model = PVCNN2( + num_classes=args.nc, + sv_points=args.svpoints, + embed_dim=args.embed_dim, + use_att=args.attention, + dropout=args.dropout, + extra_feature_channels=0, + ) def prior_kl(self, x0): return self.diffusion._prior_bpd(x0) def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt} def _denoise(self, data, t): - B, D,N= data.shape + B, D, N = data.shape assert data.dtype == torch.float assert t.shape == torch.Size([B]) and t.dtype == torch.int64 @@ -381,20 +439,22 @@ class Model(nn.Module): t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises) - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises) assert losses.shape == t.shape == torch.Size([B]) return losses - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False): + return self.diffusion.p_sample_loop( + partial_x, + self._denoise, + shape=shape, + device=device, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running, + ) def train(self): self.model.train() @@ -405,21 +465,19 @@ class Model(nn.Module): def multi_gpu_wrapper(self, f): self.model = f(self.model) -def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': - betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': +def get_betas(schedule_type, b_start, b_end, time_num): + if schedule_type == "linear": + betas = np.linspace(b_start, b_end, time_num) + elif schedule_type == "warm0.1": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.1) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - + elif schedule_type == "warm0.2": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.2) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - + elif schedule_type == "warm0.5": betas = b_end * np.ones(time_num, dtype=np.float64) 
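gen_samples above wires up the conditional sampler: the observed slice (the first sv_points columns) is passed separately and re-attached after every denoising step, so only the free points are ever noised or resampled. A schematic sketch of that loop, with an illustrative denoise_step standing in for the trained network:

import torch

def complete_sketch(partial_x, denoise_step, free_shape, T):
    # partial_x: (B, C, S) observed points; free_shape: (B, C, N - S) points to generate
    x_t = torch.cat([partial_x, torch.randn(free_shape)], dim=-1)
    for t in reversed(range(T)):
        free = denoise_step(x_t, t)                  # one reverse-diffusion step on the free slice
        x_t = torch.cat([partial_x, free], dim=-1)   # re-pin the observed slice every step
    return x_t

out = complete_sketch(torch.zeros(2, 3, 200), lambda x, t: x[:, :, 200:] * 0.999, (2, 3, 1848), T=10)
assert out.shape == (2, 3, 2048)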
warmup_time = int(time_num * 0.5) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) @@ -428,22 +486,29 @@ def get_betas(schedule_type, b_start, b_end, time_num): return betas - ############################################################################# -def get_mvr_dataset(pc_dataroot, views_root, npoints,category): - tr_dataset = ShapeNet15kPointClouds(root_dir=pc_dataroot, - categories=[category], split='train', + +def get_mvr_dataset(pc_dataroot, views_root, npoints, category): + tr_dataset = ShapeNet15kPointClouds( + root_dir=pc_dataroot, + categories=[category], + split="train", tr_sample_size=npoints, te_sample_size=npoints, - scale=1., + scale=1.0, normalize_per_shape=False, normalize_std_per_axis=False, - random_subsample=True) - te_dataset = ShapeNet_Multiview_Points(root_pc=pc_dataroot, root_views=views_root, - cache=os.path.join(pc_dataroot, '../cache'), split='val', + random_subsample=True, + ) + te_dataset = ShapeNet_Multiview_Points( + root_pc=pc_dataroot, + root_views=views_root, + cache=os.path.join(pc_dataroot, "../cache"), + split="val", categories=[category], - npoints=npoints, sv_samples=200, + npoints=npoints, + sv_samples=200, all_points_mean=tr_dataset.all_points_mean, all_points_std=tr_dataset.all_points_std, ) @@ -451,39 +516,41 @@ def get_mvr_dataset(pc_dataroot, views_root, npoints,category): def evaluate_recon_mvr(opt, model, save_dir, logger): - test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, - opt.npoints, opt.category) + test_dataset = get_mvr_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.category) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) + test_dataloader = torch.utils.data.DataLoader( + test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False + ) ref = [] samples = [] masked = [] - k = 0 - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Reconstructing Samples'): + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Reconstructing Samples"): + gt_all = data["test_points"] + x_all = data["sv_points"] - gt_all = data['test_points'] - x_all = data['sv_points'] - - B,V,N,C = x_all.shape - gt_all = gt_all[:,None,:,:].expand(-1, V, -1,-1) + B, V, N, C = x_all.shape + gt_all = gt_all[:, None, :, :].expand(-1, V, -1, -1) x = x_all.reshape(B * V, N, C).transpose(1, 2).contiguous() - m, s = data['mean'].float(), data['std'].float() + m, s = data["mean"].float(), data["std"].float() - recon = model.gen_samples(x[:, :, :opt.svpoints].cuda(), x[:, :, opt.svpoints:].shape, 'cuda', - clip_denoised=False).detach().cpu() + recon = ( + model.gen_samples( + x[:, :, : opt.svpoints].cuda(), x[:, :, opt.svpoints :].shape, "cuda", clip_denoised=False + ) + .detach() + .cpu() + ) recon = recon.transpose(1, 2).contiguous() x = x.transpose(1, 2).contiguous() + x_adj = x.reshape(B, V, N, C) * s + m + recon_adj = recon.reshape(B, V, N, C) * s + m - x_adj = x.reshape(B,V,N,C)* s + m - recon_adj = recon.reshape(B,V,N,C)* s + m - - ref.append( gt_all * s + m) - masked.append(x_adj[:,:,:test_dataloader.dataset.sv_samples,:]) + ref.append(gt_all * s + m) + masked.append(x_adj[:, :, : test_dataloader.dataset.sv_samples, :]) samples.append(recon_adj) ref_pcs = torch.cat(ref, dim=0) @@ -492,31 +559,40 @@ def evaluate_recon_mvr(opt, model, save_dir, logger): B, V, N, C = ref_pcs.shape + torch.save(ref_pcs.reshape(B, 
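The warm schedules in get_betas share one pattern: hold beta at b_end and linearly ramp only the first fraction of steps. A compact restatement (the 1e-5/0.008 pair matches the airplane override applied later in test_generation's main):

import numpy as np

def warm_betas(b_start, b_end, time_num, frac):
    # 'warm{frac}' schedule: linear ramp over the first frac of steps, constant b_end after
    betas = b_end * np.ones(time_num, dtype=np.float64)
    warmup = int(time_num * frac)
    betas[:warmup] = np.linspace(b_start, b_end, warmup, dtype=np.float64)
    return betas

betas = warm_betas(1e-5, 0.008, 1000, 0.1)   # 'warm0.1'
assert betas[0] == 1e-5 and (betas[100:] == 0.008).all()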
V, N, C), os.path.join(save_dir, "recon_gt.pth")) - torch.save(ref_pcs.reshape(B,V, N, C), os.path.join(save_dir, 'recon_gt.pth')) - - torch.save(masked.reshape(B,V, *masked.shape[2:]), os.path.join(save_dir, 'recon_masked.pth')) + torch.save(masked.reshape(B, V, *masked.shape[2:]), os.path.join(save_dir, "recon_masked.pth")) # Compute metrics - results = EMD_CD(sample_pcs.reshape(B*V, N, C), - ref_pcs.reshape(B*V, N, C), opt.batch_size, reduced=False) + results = EMD_CD(sample_pcs.reshape(B * V, N, C), ref_pcs.reshape(B * V, N, C), opt.batch_size, reduced=False) - results = {ky: val.reshape(B,V) if val.shape == torch.Size([B*V,]) else val for ky, val in results.items()} + results = { + ky: val.reshape(B, V) + if val.shape + == torch.Size( + [ + B * V, + ] + ) + else val + for ky, val in results.items() + } pprint({key: val.mean().item() for key, val in results.items()}) logger.info({key: val.mean().item() for key, val in results.items()}) - results['pc'] = sample_pcs - torch.save(results, os.path.join(save_dir, 'ours_results.pth')) + results["pc"] = sample_pcs + torch.save(results, os.path.join(save_dir, "ours_results.pth")) del ref_pcs, masked, results + def evaluate_saved(opt, saved_dir): # ours_base = '/viscam/u/alexzhou907/research/diffusion/shape_completion/output/test_chair/2020-11-04-02-10-38/syn' - gt_pth = saved_dir + '/recon_gt.pth' - ours_pth = saved_dir + '/ours_results.pth' - gt = torch.load(gt_pth).permute(1,0,2,3) - ours = torch.load(ours_pth)['pc'].permute(1,0,2,3) + gt_pth = saved_dir + "/recon_gt.pth" + ours_pth = saved_dir + "/ours_results.pth" + gt = torch.load(gt_pth).permute(1, 0, 2, 3) + ours = torch.load(ours_pth)["pc"].permute(1, 0, 2, 3) all_res = {} for i, (gt_, ours_) in enumerate(zip(gt, ours)): @@ -534,7 +610,6 @@ def evaluate_saved(opt, saved_dir): pprint({key: val.mean().item() for key, val in all_res.items()}) - def main(opt): exp_id = os.path.splitext(os.path.basename(__file__))[0] dir_id = os.path.dirname(__file__) @@ -542,7 +617,7 @@ def main(opt): copy_source(__file__, output_dir) logger = setup_logging(output_dir) - outf_syn, = setup_output_subdirs(output_dir, 'syn') + (outf_syn,) = setup_output_subdirs(output_dir, "syn") betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) @@ -559,12 +634,10 @@ def main(opt): model.eval() with torch.no_grad(): - logger.info("Resume Path:%s" % opt.model) resumed_param = torch.load(opt.model) - model.load_state_dict(resumed_param['model_state']) - + model.load_state_dict(resumed_param["model_state"]) if opt.eval_recon_mvr: # Evaluate generation @@ -575,47 +648,44 @@ def main(opt): def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/') - parser.add_argument('--dataroot_sv', default='GenReData/') - parser.add_argument('--category', default='chair') + parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/") + parser.add_argument("--dataroot_sv", default="GenReData/") + parser.add_argument("--category", default="chair") - parser.add_argument('--batch_size', type=int, default=50, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + parser.add_argument("--batch_size", type=int, default=50, help="input batch size") + parser.add_argument("--workers", type=int, default=16, help="workers") + 
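EMD_CD with reduced=False returns one value per cloud over the flattened batch-times-views axis; the dict comprehension above folds those back to [B, V] so means can be taken per shape or per view. The same reshape on dummy data (the is_tensor guard here is an addition for non-tensor entries, not part of the patch):

import torch

B, V = 2, 3
flat = {"CD": torch.rand(B * V), "EMD": torch.rand(B * V), "lr": 0.001}
per_view = {k: v.reshape(B, V) if torch.is_tensor(v) and v.shape == torch.Size([B * V]) else v
            for k, v in flat.items()}
assert per_view["CD"].shape == (B, V) and per_view["lr"] == 0.001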
parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for") - parser.add_argument('--eval_recon_mvr', default=True) - parser.add_argument('--eval_saved', default=True) + parser.add_argument("--eval_recon_mvr", default=True) + parser.add_argument("--eval_saved", default=True) - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) + parser.add_argument("--nc", default=3) + parser.add_argument("--npoints", default=2048) + parser.add_argument("--svpoints", default=200) + """model""" + parser.add_argument("--beta_start", default=0.0001) + parser.add_argument("--beta_end", default=0.02) + parser.add_argument("--schedule_type", default="linear") + parser.add_argument("--time_num", default=1000) - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') + # params + parser.add_argument("--attention", default=True) + parser.add_argument("--dropout", default=0.1) + parser.add_argument("--embed_dim", type=int, default=64) + parser.add_argument("--loss_type", default="mse") + parser.add_argument("--model_mean_type", default="eps") + parser.add_argument("--model_var_type", default="fixedsmall") + parser.add_argument("--model", default="", required=True, help="path to model (to continue training)") - parser.add_argument('--model', default='', required=True, help="path to model (to continue training)") + """eval""" - '''eval''' + parser.add_argument("--eval_path", default="") - parser.add_argument('--eval_path', - default='') + parser.add_argument("--manualSeed", default=42, type=int, help="random seed") - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)") opt = parser.parse_args() @@ -625,7 +695,9 @@ def parse_args(): opt.cuda = False return opt -if __name__ == '__main__': + + +if __name__ == "__main__": opt = parse_args() main(opt) diff --git a/test_generation.py b/test_generation.py index 7d35c54..40aafde 100644 --- a/test_generation.py +++ b/test_generation.py @@ -1,31 +1,30 @@ -import torch +import argparse from pprint import pprint -from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD -from metrics.evaluation_metrics import compute_all_metrics, EMD_CD +import torch import torch.nn as nn import torch.utils.data - -import argparse from torch.distributions import Normal - -from utils.file_utils import * -from utils.visualize import * -from model.pvcnn_generation import PVCNN2Base - from tqdm import tqdm from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from metrics.evaluation_metrics import compute_all_metrics +from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD +from model.pvcnn_generation import PVCNN2Base +from utils.file_utils import * +from utils.visualize import * -''' +""" models -''' +""" + + def 
normal_kl(mean1, logvar1, mean2, logvar2): """ KL divergence between normal distributions parameterized by mean and log-variance. """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2)) + def discretized_gaussian_log_likelihood(x, *, means, log_scales): # Assumes data is integers [0, 1] @@ -36,37 +35,40 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 0.5) cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) + min_in = inv_stdv * (centered_x - 0.5) cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + x < 0.001, + log_cdf_plus, + torch.where( + x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12)) + ), + ) assert log_probs.shape == x.shape return log_probs class GaussianDiffusion: - def __init__(self,betas, loss_type, model_mean_type, model_var_type): + def __init__(self, betas, loss_type, model_mean_type, model_var_type): self.loss_type = loss_type self.model_mean_type = model_mean_type self.model_var_type = model_var_type assert isinstance(betas, np.ndarray) self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) # initialize twice the actual length so we can keep running for eval # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - alphas = 1. - betas + alphas = 1.0 - betas alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float() self.betas = torch.from_numpy(betas).float() self.alphas_cumprod = alphas_cumprod.float() @@ -74,21 +76,23 @@ class GaussianDiffusion: # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. 
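normal_kl above is the closed-form KL between diagonal Gaussians expressed in log-variance, 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2) - 1). It can be checked against torch.distributions (scale = exp(logvar / 2)):

import torch
from torch.distributions import Normal, kl_divergence

def normal_kl(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * torch.exp(-logvar2))

m1, lv1 = torch.tensor(0.3), torch.tensor(-0.5)
m2, lv2 = torch.tensor(-0.1), torch.tensor(0.2)
ref = kl_divergence(Normal(m1, (0.5 * lv1).exp()), Normal(m2, (0.5 * lv2).exp()))
assert torch.allclose(normal_kl(m1, lv1, m2, lv2), ref)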
/ alphas_cumprod - 1).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float() betas = torch.from_numpy(betas).float() alphas = torch.from_numpy(alphas).float() # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.posterior_variance = posterior_variance # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)) + ) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) @staticmethod def _extract(a, t, x_shape): @@ -96,17 +100,15 @@ class GaussianDiffusion: Extract some coefficients at specified timesteps, then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. """ - bs, = t.shape + (bs,) = t.shape assert x_shape[0] == bs out = torch.gather(a, 0, t) assert out.shape == torch.Size([bs]) return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - def q_mean_variance(self, x_start, t): mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. 
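_extract is the broadcasting workhorse for all of these coefficient tables: it gathers one per-timestep scalar per batch element from a 1-D table and reshapes it to the rank of x with trailing singleton dims. Restated on toy shapes:

import torch

def extract(a, t, x_shape):
    (bs,) = t.shape
    out = torch.gather(a, 0, t)                      # a[t_i] for each batch element
    return out.reshape(bs, *([1] * (len(x_shape) - 1)))

coefs = torch.linspace(0.0, 1.0, 10)                 # e.g. a sqrt_alphas_cumprod table
t = torch.tensor([0, 9, 4])
x = torch.zeros(3, 3, 2048)                          # (B, C, N) point clouds
assert extract(coefs, t, x.shape).shape == (3, 1, 1)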
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape) log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) return mean, variance, log_variance @@ -118,56 +120,62 @@ class GaussianDiffusion: noise = torch.randn(x_start.shape, device=x_start.device) assert noise.shape == x_start.shape return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise ) - def q_posterior_mean_variance(self, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t ) posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) + posterior_log_variance_clipped = self._extract( + self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - model_output = denoise_fn(data, t) - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: + if self.model_var_type in ["fixedsmall", "fixedlarge"]: # below: only log_variance is used in the KL computations model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + "fixedlarge": ( + self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device), + ), + "fixedsmall": ( + self.posterior_variance.to(data.device), + self.posterior_log_variance_clipped.to(data.device), + ), }[self.model_var_type] model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) else: raise NotImplementedError(self.model_var_type) - if self.model_mean_type == 'eps': + if self.model_mean_type == "eps": x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) if clip_denoised: - x_recon = torch.clamp(x_recon, -.5, .5) + x_recon = torch.clamp(x_recon, -0.5, 0.5) model_mean, _, _ = 
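q_sample draws x_t directly from the closed-form marginal x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so a unit-variance signal keeps unit variance at every t. A quick statistical check with a toy alpha_bar (not the patch's schedule):

import torch

a_bar = torch.tensor(0.7)                    # stand-in for alphas_cumprod[t]
x0 = torch.randn(100_000)
xt = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * torch.randn_like(x0)
print(xt.var())                              # ~1.0, since Var = 0.7 * Var(x0) + 0.3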
self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) else: raise NotImplementedError(self.loss_type) - assert model_mean.shape == x_recon.shape == data.shape assert model_variance.shape == model_log_variance.shape == data.shape if return_pred_xstart: @@ -178,18 +186,19 @@ class GaussianDiffusion: def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t + - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps ) - ''' samples ''' + """ samples """ def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False, use_var=True): """ Sample from the model """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) assert noise.shape == data.shape # no noise when t == 0 @@ -201,10 +210,17 @@ class GaussianDiffusion: assert sample.shape == pred_xstart.shape return (sample, pred_xstart) if return_pred_xstart else sample - - def p_sample_loop(self, denoise_fn, shape, device, - noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=True, max_timestep=None, keep_running=False): + def p_sample_loop( + self, + denoise_fn, + shape, + device, + noise_fn=torch.randn, + constrain_fn=lambda x, t: x, + clip_denoised=True, + max_timestep=None, + keep_running=False, + ): """ Generate samples keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps @@ -220,28 +236,38 @@ class GaussianDiffusion: for t in reversed(range(0, final_time if not keep_running else len(self.betas))): img_t = constrain_fn(img_t, t) t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False).detach() - + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ).detach() assert img_t.shape == shape return img_t - def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t:x): - + def reconstruct(self, x0, t, denoise_fn, noise_fn=torch.randn, constrain_fn=lambda x, t: x): assert t >= 1 - t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t-1) + t_vec = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(t - 1) encoding = self.q_sample(x0, t_vec) img_t = encoding - for k in reversed(range(0,t)): + for k in reversed(range(0, t)): img_t = constrain_fn(img_t, k) t_ = torch.empty(x0.shape[0], dtype=torch.int64, device=x0.device).fill_(k) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=False, return_pred_xstart=False, use_var=True).detach() - + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=False, + return_pred_xstart=False, + use_var=True, + ).detach() return img_t @@ -260,40 +286,50 @@ class PVCNN2(PVCNN2Base): 
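In p_sample the nonzero_mask zeroes the noise exactly where t == 0, so the final reverse step returns the posterior mean rather than a sample; the reshape keeps the mask broadcastable over (B, C, N). The same arithmetic in isolation:

import torch

B = 4
t = torch.tensor([0, 3, 0, 7])
model_mean = torch.zeros(B, 3, 16)
model_log_variance = torch.zeros(B, 3, 16)
noise = torch.randn(B, 3, 16)
nonzero_mask = (1 - (t == 0).float()).reshape(B, 1, 1)
sample = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
assert (sample[t == 0] == model_mean[t == 0]).all()   # no noise where t == 0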
((128, 128, 64), (64, 2, 32)), ] - def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): + def __init__( + self, + num_classes, + embed_dim, + use_att, + dropout, + extra_feature_channels=3, + width_multiplier=1, + voxel_resolution_multiplier=1, + ): super().__init__( - num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + num_classes=num_classes, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, ) - class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str): super(Model, self).__init__() self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) - self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) + self.model = PVCNN2( + num_classes=args.nc, + embed_dim=args.embed_dim, + use_att=args.attention, + dropout=args.dropout, + extra_feature_channels=0, + ) def prior_kl(self, x0): return self.diffusion._prior_bpd(x0) def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt} def _denoise(self, data, t): - B, D,N= data.shape + B, D, N = data.shape assert data.dtype == torch.float assert t.shape == torch.Size([B]) and t.dtype == torch.int64 @@ -307,23 +343,34 @@ class Model(nn.Module): t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises) - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises) assert losses.shape == t.shape == torch.Size([B]) return losses - def gen_samples(self, shape, device, noise_fn=torch.randn, constrain_fn=lambda x, t:x, - clip_denoised=False, max_timestep=None, - keep_running=False): - return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, - constrain_fn=constrain_fn, - clip_denoised=clip_denoised, max_timestep=max_timestep, - keep_running=keep_running) - - def reconstruct(self, x0, t, constrain_fn=lambda x, t:x): + def gen_samples( + self, + shape, + device, + noise_fn=torch.randn, + constrain_fn=lambda x, t: x, + clip_denoised=False, + max_timestep=None, + keep_running=False, + ): + return self.diffusion.p_sample_loop( + self._denoise, + shape=shape, + device=device, + noise_fn=noise_fn, + constrain_fn=constrain_fn, + clip_denoised=clip_denoised, + max_timestep=max_timestep, + keep_running=keep_running, + ) + 
def reconstruct(self, x0, t, constrain_fn=lambda x, t: x): return self.diffusion.reconstruct(x0, t, self._denoise, constrain_fn=constrain_fn) def train(self): @@ -337,20 +384,17 @@ class Model(nn.Module): def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': + if schedule_type == "linear": betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - + elif schedule_type == "warm0.1": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.1) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - + elif schedule_type == "warm0.2": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.2) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - + elif schedule_type == "warm0.5": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.5) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) @@ -358,111 +402,109 @@ def get_betas(schedule_type, b_start, b_end, time_num): raise NotImplementedError(schedule_type) return betas + def get_constrain_function(ground_truth, mask, eps, num_steps=1): - ''' + """ :param target_shape_constraint: target voxels :return: constrained x - ''' + """ # eps_all = list(reversed(np.linspace(0,np.float_power(eps, 1/2), 500)**2)) - eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000)**2 )) - def constrain_fn(x, t): - eps_ = eps_all[t] if (t<1000) else 0 - for _ in range(num_steps): - x = x - eps_ * ((x - ground_truth) * mask) + eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2)) + def constrain_fn(x, t): + eps_ = eps_all[t] if (t < 1000) else 0 + for _ in range(num_steps): + x = x - eps_ * ((x - ground_truth) * mask) return x + return constrain_fn ############################################################################# -def get_dataset(dataroot, npoints,category,use_mask=False): - tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=[category], split='train', + +def get_dataset(dataroot, npoints, category, use_mask=False): + tr_dataset = ShapeNet15kPointClouds( + root_dir=dataroot, + categories=[category], + split="train", tr_sample_size=npoints, te_sample_size=npoints, - scale=1., + scale=1.0, normalize_per_shape=False, normalize_std_per_axis=False, - random_subsample=True, use_mask = use_mask) - te_dataset = ShapeNet15kPointClouds(root_dir=dataroot, - categories=[category], split='val', + random_subsample=True, + use_mask=use_mask, + ) + te_dataset = ShapeNet15kPointClouds( + root_dir=dataroot, + categories=[category], + split="val", tr_sample_size=npoints, te_sample_size=npoints, - scale=1., + scale=1.0, normalize_per_shape=False, normalize_std_per_axis=False, all_points_mean=tr_dataset.all_points_mean, all_points_std=tr_dataset.all_points_std, - use_mask=use_mask + use_mask=use_mask, ) return tr_dataset, te_dataset - def evaluate_gen(opt, ref_pcs, logger): - if ref_pcs is None: _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category, use_mask=False) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) + test_dataloader = torch.utils.data.DataLoader( + test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False + ) ref = [] - for data in tqdm(test_dataloader, total=len(test_dataloader), 
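get_constrain_function anneals the correction strength quadratically from eps down to 0 across the 1000 reverse steps (weak early in sampling, strongest near t = 0), and each call is a masked gradient step on 0.5 * ||x - ground_truth||^2 restricted to the masked entries. The update on toy tensors:

import numpy as np
import torch

eps = 1e-2
eps_all = list(reversed(np.linspace(0, np.sqrt(eps), 1000) ** 2))   # eps_all[0] ~ eps, eps_all[999] = 0

x = torch.zeros(2, 3, 8)
gt = torch.ones_like(x)
mask = torch.zeros_like(x)
mask[..., :4] = 1.0                          # constrain only the first 4 points
t = 0                                        # last reverse step: strongest pull
x_new = x - eps_all[t] * ((x - gt) * mask)
assert torch.allclose(x_new[..., :4], torch.full_like(x_new[..., :4], eps_all[t]))
assert (x_new[..., 4:] == 0).all()           # unmasked entries untouched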
desc='Generating Samples'): - x = data['test_points'] - m, s = data['mean'].float(), data['std'].float() + for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"): + x = data["test_points"] + m, s = data["mean"].float(), data["std"].float() - ref.append(x*s + m) + ref.append(x * s + m) ref_pcs = torch.cat(ref, dim=0).contiguous() - logger.info("Loading sample path: %s" - % (opt.eval_path)) + logger.info("Loading sample path: %s" % (opt.eval_path)) sample_pcs = torch.load(opt.eval_path).contiguous() - logger.info("Generation sample size:%s reference size: %s" - % (sample_pcs.size(), ref_pcs.size())) - + logger.info("Generation sample size:%s reference size: %s" % (sample_pcs.size(), ref_pcs.size())) # Compute metrics results = compute_all_metrics(sample_pcs, ref_pcs, opt.batch_size) - results = {k: (v.cpu().detach().item() - if not isinstance(v, float) else v) for k, v in results.items()} + results = {k: (v.cpu().detach().item() if not isinstance(v, float) else v) for k, v in results.items()} pprint(results) logger.info(results) jsd = JSD(sample_pcs.numpy(), ref_pcs.numpy()) - pprint('JSD: {}'.format(jsd)) - logger.info('JSD: {}'.format(jsd)) - + pprint("JSD: {}".format(jsd)) + logger.info("JSD: {}".format(jsd)) def generate(model, opt): - _, test_dataset = get_dataset(opt.dataroot, opt.npoints, opt.category) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batch_size, - shuffle=False, num_workers=int(opt.workers), drop_last=False) + test_dataloader = torch.utils.data.DataLoader( + test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False + ) with torch.no_grad(): - samples = [] ref = [] - for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc='Generating Samples'): - - x = data['test_points'].transpose(1,2) - m, s = data['mean'].float(), data['std'].float() - - gen = model.gen_samples(x.shape, - 'cuda', clip_denoised=False).detach().cpu() - - gen = gen.transpose(1,2).contiguous() - x = x.transpose(1,2).contiguous() + for i, data in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Generating Samples"): + x = data["test_points"].transpose(1, 2) + m, s = data["mean"].float(), data["std"].float() + gen = model.gen_samples(x.shape, "cuda", clip_denoised=False).detach().cpu() + gen = gen.transpose(1, 2).contiguous() + x = x.transpose(1, 2).contiguous() gen = gen * s + m x = x * s + m @@ -482,20 +524,20 @@ def generate(model, opt): # 1, # 0.5, # ) - # visualize using matplotlib - import matplotlib.pyplot as plt - import matplotlib.cm as cm import matplotlib - matplotlib.use('TkAgg') + import matplotlib.cm as cm + import matplotlib.pyplot as plt + + matplotlib.use("TkAgg") for idx, pc in enumerate(gen[:64]): print(f"Visualizing point cloud {idx}...") fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - ax.scatter(pc[:,0], pc[:,1], pc[:,2], c=pc[:,2], cmap=cm.jet) - ax.set_aspect('equal') - ax.axis('off') + ax = fig.add_subplot(111, projection="3d") + ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet) + ax.set_aspect("equal") + ax.axis("off") # ax.set_xlim(-1, 1) # ax.set_ylim(-1, 1) # ax.set_zlim(-1, 1) @@ -507,17 +549,14 @@ def generate(model, opt): torch.save(samples, opt.eval_path) - - return ref def main(opt): - - if opt.category == 'airplane': + if opt.category == "airplane": opt.beta_start = 1e-5 opt.beta_end = 0.008 - opt.schedule_type = 'warm0.1' + opt.schedule_type = "warm0.1" exp_id = 
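The inline visualization above selects the TkAgg backend after pyplot is already imported (modern matplotlib will still switch, but it requires a display) and calls set_aspect("equal") on a 3D axes, which raises NotImplementedError on matplotlib versions before roughly 3.6. A headless-safe variant under those assumptions, using set_box_aspect (available since about 3.3) as the portable equal-aspect control:

import matplotlib
matplotlib.use("Agg")                        # headless backend: render to file, not a window
import matplotlib.pyplot as plt
import numpy as np

pc = np.random.rand(2048, 3)                 # stand-in for one generated cloud
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap="jet", s=1)
ax.set_box_aspect((1, 1, 1))                 # portable alternative to set_aspect('equal') in 3D
ax.axis("off")
fig.savefig("pointcloud_0.png", bbox_inches="tight")
plt.close(fig)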
os.path.splitext(os.path.basename(__file__))[0] dir_id = os.path.dirname(__file__) @@ -525,7 +564,7 @@ def main(opt): copy_source(__file__, output_dir) logger = setup_logging(output_dir) - outf_syn, = setup_output_subdirs(output_dir, 'syn') + (outf_syn,) = setup_output_subdirs(output_dir, "syn") betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) @@ -542,64 +581,59 @@ def main(opt): model.eval() with torch.no_grad(): - logger.info("Resume Path:%s" % opt.model) resumed_param = torch.load(opt.model) - model.load_state_dict(resumed_param['model_state']) - + model.load_state_dict(resumed_param["model_state"]) ref = None if opt.generate: - opt.eval_path = os.path.join(outf_syn, 'samples.pth') + opt.eval_path = os.path.join(outf_syn, "samples.pth") Path(opt.eval_path).parent.mkdir(parents=True, exist_ok=True) - ref=generate(model, opt) - + ref = generate(model, opt) + if opt.eval_gen: # Evaluate generation evaluate_gen(opt, ref, logger) def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/') - parser.add_argument('--category', default='chair') + parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/") + parser.add_argument("--category", default="chair") - parser.add_argument('--batch_size', type=int, default=50, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + parser.add_argument("--batch_size", type=int, default=50, help="input batch size") + parser.add_argument("--workers", type=int, default=16, help="workers") + parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for") - parser.add_argument('--generate',default=True) - parser.add_argument('--eval_gen', default=True) + parser.add_argument("--generate", default=True) + parser.add_argument("--eval_gen", default=True) - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) + parser.add_argument("--nc", default=3) + parser.add_argument("--npoints", default=2048) + """model""" + parser.add_argument("--beta_start", default=0.0001) + parser.add_argument("--beta_end", default=0.02) + parser.add_argument("--schedule_type", default="linear") + parser.add_argument("--time_num", default=1000) - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - parser.add_argument('--loss_type', default='mse') - parser.add_argument('--model_mean_type', default='eps') - parser.add_argument('--model_var_type', default='fixedsmall') + # params + parser.add_argument("--attention", default=True) + parser.add_argument("--dropout", default=0.1) + parser.add_argument("--embed_dim", type=int, default=64) + parser.add_argument("--loss_type", default="mse") + parser.add_argument("--model_mean_type", default="eps") + parser.add_argument("--model_var_type", default="fixedsmall") + parser.add_argument("--model", default="", required=True, help="path to model (to continue training)") - parser.add_argument('--model', default='',required=True, 
help="path to model (to continue training)") + """eval""" - '''eval''' + parser.add_argument("--eval_path", default="") - parser.add_argument('--eval_path', - default='') + parser.add_argument("--manualSeed", default=42, type=int, help="random seed") - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') - - parser.add_argument('--gpu', type=int, default=0, metavar='S', help='gpu id (default: 0)') + parser.add_argument("--gpu", type=int, default=0, metavar="S", help="gpu id (default: 0)") opt = parser.parse_args() @@ -609,7 +643,9 @@ def parse_args(): opt.cuda = False return opt -if __name__ == '__main__': + + +if __name__ == "__main__": opt = parse_args() set_seed(opt) diff --git a/train_completion.py b/train_completion.py index 8d81818..e5a5afc 100644 --- a/train_completion.py +++ b/train_completion.py @@ -1,20 +1,23 @@ +import argparse + +import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim import torch.utils.data - -import argparse from torch.distributions import Normal -from utils.file_utils import * -from utils.visualize import * -from model.pvcnn_completion import PVCNN2Base -import torch.distributed as dist from datasets.shapenet_data_pc import ShapeNet15kPointClouds from datasets.shapenet_data_sv import ShapeNet_Multiview_Points -''' +from model.pvcnn_completion import PVCNN2Base +from utils.file_utils import * +from utils.visualize import * + +""" some utils -''' +""" + + def rotation_matrix(axis, theta): """ Return the rotation matrix associated with counterclockwise rotation about @@ -26,29 +29,36 @@ def rotation_matrix(axis, theta): b, c, d = -axis * np.sin(theta / 2.0) aa, bb, cc, dd = a * a, b * b, c * c, d * d bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + return np.array( + [ + [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc], + ] + ) + def rotate(vertices, faces): - ''' + """ vertices: [numpoints, 3] - ''' + """ M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() K = rotation_matrix([0, 0, 1], np.pi).transpose() - v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]] return v, f + def norm(v, f): - v = (v - v.min())/(v.max() - v.min()) - 0.5 + v = (v - v.min()) / (v.max() - v.min()) - 0.5 return v, f + def getGradNorm(net): - pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) - gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + pNorm = torch.sqrt(sum(torch.sum(p**2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad**2) for p in net.parameters())) return pNorm, gradNorm @@ -57,21 +67,24 @@ def weights_init(m): xavier initialization """ classname = m.__class__.__name__ - if classname.find('Conv') != -1 and m.weight is not None: + if classname.find("Conv") != -1 and m.weight is not None: torch.nn.init.xavier_normal_(m.weight) - elif classname.find('BatchNorm') != -1: + elif classname.find("BatchNorm") != -1: m.weight.data.normal_() m.bias.data.fill_(0) -''' + +""" models -''' +""" + + def normal_kl(mean1, logvar1, mean2, logvar2): """ KL divergence between normal distributions 
parameterized by mean and log-variance. """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2)) + def discretized_gaussian_log_likelihood(x, *, means, log_scales): # Assumes data is integers [0, 1] @@ -82,19 +95,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 0.5) cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) + min_in = inv_stdv * (centered_x - 0.5) cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + x < 0.001, + log_cdf_plus, + torch.where( + x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12)) + ), + ) assert log_probs.shape == x.shape return log_probs + class GaussianDiffusion: def __init__(self, betas, loss_type, model_mean_type, model_var_type, sv_points): self.loss_type = loss_type @@ -103,15 +120,15 @@ class GaussianDiffusion: assert isinstance(betas, np.ndarray) self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.sv_points = sv_points # initialize twice the actual length so we can keep running for eval # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - alphas = 1. - betas + alphas = 1.0 - betas alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float() self.betas = torch.from_numpy(betas).float() self.alphas_cumprod = alphas_cumprod.float() @@ -119,21 +136,23 @@ class GaussianDiffusion: # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float() betas = torch.from_numpy(betas).float() alphas = torch.from_numpy(alphas).float() # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. 
- alphas_cumprod) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.posterior_variance = posterior_variance # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)) + ) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) @staticmethod def _extract(a, t, x_shape): @@ -141,17 +160,15 @@ class GaussianDiffusion: Extract some coefficients at specified timesteps, then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. """ - bs, = t.shape + (bs,) = t.shape assert x_shape[0] == bs out = torch.gather(a, 0, t) assert out.shape == torch.Size([bs]) return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - def q_mean_variance(self, x_start, t): mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. - self.alphas_cumprod.to(x_start.device), t, x_start.shape) + variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape) log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) return mean, variance, log_variance @@ -163,54 +180,59 @@ class GaussianDiffusion: noise = torch.randn(x_start.shape, device=x_start.device) assert noise.shape == x_start.shape return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise ) - def q_posterior_mean_variance(self, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t ) posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) + posterior_log_variance_clipped = self._extract( + self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) return posterior_mean, 
posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): + model_output = denoise_fn(data, t)[:, :, self.sv_points :] - model_output = denoise_fn(data, t)[:,:,self.sv_points:] - - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: + if self.model_var_type in ["fixedsmall", "fixedlarge"]: # below: only log_variance is used in the KL computations model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + "fixedlarge": ( + self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device), + ), + "fixedsmall": ( + self.posterior_variance.to(data.device), + self.posterior_log_variance_clipped.to(data.device), + ), }[self.model_var_type] model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(model_output) model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(model_output) else: raise NotImplementedError(self.model_var_type) - if self.model_mean_type == 'eps': - x_recon = self._predict_xstart_from_eps(data[:,:,self.sv_points:], t=t, eps=model_output) + if self.model_mean_type == "eps": + x_recon = self._predict_xstart_from_eps(data[:, :, self.sv_points :], t=t, eps=model_output) - - model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:,:,self.sv_points:], t=t) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data[:, :, self.sv_points :], t=t) else: raise NotImplementedError(self.loss_type) - assert model_mean.shape == x_recon.shape assert model_variance.shape == model_log_variance.shape if return_pred_xstart: @@ -221,30 +243,31 @@ class GaussianDiffusion: def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t + - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps ) - ''' samples ''' + """ samples """ def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): """ Sample from the model """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) noise = noise_fn(size=model_mean.shape, dtype=model_mean.dtype, device=model_mean.device) # no noise when t == 0 nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(model_mean.shape) - 1)) sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise - sample = torch.cat([data[:, :, :self.sv_points], sample], dim=-1) + sample = torch.cat([data[:, :, : self.sv_points], sample], dim=-1) return (sample, pred_xstart) if return_pred_xstart else sample - - def p_sample_loop(self, partial_x, denoise_fn, shape, device, - noise_fn=torch.randn, 
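The 'fixedsmall'/'fixedlarge' pair above fixes the reverse variance rather than learning it: 'fixedsmall' uses the true posterior variance beta_t * (1 - abar_{t-1}) / (1 - abar_t), while 'fixedlarge' uses beta_t itself, with the first entry borrowed from the posterior so the t = 0 log-variance stays finite. Since the posterior variance never exceeds beta_t, 'fixedlarge' really is the larger choice everywhere:

import torch

betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64)        # the default linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]])
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

log_small = torch.log(torch.clamp(posterior_variance, min=1e-20))    # clipped: entry 0 is exactly 0
log_large = torch.log(torch.cat([posterior_variance[1:2], betas[1:]]))
assert (log_large >= log_small).all()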
clip_denoised=True, keep_running=False): + def p_sample_loop( + self, partial_x, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False + ): """ Generate samples keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps @@ -255,14 +278,21 @@ class GaussianDiffusion: img_t = torch.cat([partial_x, noise_fn(size=shape, dtype=torch.float, device=device)], dim=-1) for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ) - assert img_t[:,:,self.sv_points:].shape == shape + assert img_t[:, :, self.sv_points :].shape == shape return img_t - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): + def p_sample_loop_trajectory( + self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False + ): """ Generate samples, returning intermediate images Useful for visualizing how denoised images evolve over time @@ -272,31 +302,38 @@ class GaussianDiffusion: """ assert isinstance(shape, (tuple, list)) - total_steps = self.num_timesteps if not keep_running else len(self.betas) + total_steps = self.num_timesteps if not keep_running else len(self.betas) img_t = noise_fn(size=shape, dtype=torch.float, device=device) imgs = [img_t] - for t in reversed(range(0,total_steps)): - + for t in reversed(range(0, total_steps)): t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ) + if t % freq == 0 or t == total_steps - 1: imgs.append(img_t) assert imgs[-1].shape == shape return imgs - '''losses''' + """losses""" def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start[:,:,self.sv_points:], x_t=data_t[:,:,self.sv_points:], t=t) + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=data_start[:, :, self.sv_points :], x_t=data_t[:, :, self.sv_points :], t=t + ) model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.) 
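p_sample_loop_trajectory differs from p_sample_loop only in keeping snapshots every freq steps (plus the very first reverse step), which is what makes the denoising progression visualizable. The bookkeeping in isolation, with a dummy one-step map in place of the real sampler:

import torch

T, freq = 1000, 100
img = torch.randn(2, 3, 16)
traj = [img]
for t in reversed(range(T)):
    img = 0.999 * img                        # stand-in for one p_sample step
    if t % freq == 0 or t == T - 1:
        traj.append(img)
print(len(traj))                             # 12: initial noise, the t=999 state, and every 100th step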
+ kl = kl.mean(dim=list(range(1, len(model_mean.shape)))) / np.log(2.0) return (kl, pred_xstart) if return_pred_xstart else kl @@ -308,66 +345,87 @@ class GaussianDiffusion: assert t.shape == torch.Size([B]) if noise is None: - noise = torch.randn(data_start[:,:,self.sv_points:].shape, dtype=data_start.dtype, device=data_start.device) + noise = torch.randn( + data_start[:, :, self.sv_points :].shape, dtype=data_start.dtype, device=data_start.device + ) - data_t = self.q_sample(x_start=data_start[:,:,self.sv_points:], t=t, noise=noise) + data_t = self.q_sample(x_start=data_start[:, :, self.sv_points :], t=t, noise=noise) - if self.loss_type == 'mse': + if self.loss_type == "mse": # predict the noise instead of x_start. seems to be weighted naturally like SNR - eps_recon = denoise_fn(torch.cat([data_start[:,:,:self.sv_points], data_t], dim=-1), t)[:,:,self.sv_points:] + eps_recon = denoise_fn(torch.cat([data_start[:, :, : self.sv_points], data_t], dim=-1), t)[ + :, :, self.sv_points : + ] - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': + losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == "kl": losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) + denoise_fn=denoise_fn, + data_start=data_start, + data_t=data_t, + t=t, + clip_denoised=False, + return_pred_xstart=False, + ) else: raise NotImplementedError(self.loss_type) assert losses.shape == torch.Size([B]) return losses - '''debug''' + """debug""" def _prior_bpd(self, x_start): - with torch.no_grad(): B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + kl_prior = normal_kl( + mean1=qt_mean, + logvar1=qt_log_variance, + mean2=torch.tensor([0.0]).to(qt_mean), + logvar2=torch.tensor([0.0]).to(qt_log_variance), + ) assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
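
# --- illustrative aside (minimal sketch; schedule values assume the linear
# defaults used elsewhere in this file): _prior_bpd measures how far
# q(x_T | x_0) = N(sqrt(a_T) * x_0, (1 - a_T) * I), with a_T = prod_t(1 - beta_t),
# is from the N(0, I) prior that sampling starts from. After 1000 linear steps the
# match is already very close, so this term is typically tiny.
import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)
a_T = np.cumprod(1.0 - betas)[-1]
print(np.sqrt(a_T), 1.0 - a_T)  # ~0.007 and ~1.0, i.e. q(x_T | x_0) is close to N(0, I)
# --- end aside
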
+ return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0) def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - with torch.no_grad(): B, T = x_start.shape[0], self.num_timesteps - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) for t in reversed(range(T)): - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) # Calculate VLB term at the current timestep - data_t = torch.cat([x_start[:, :, :self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points:], t=t_b)], dim=-1) + data_t = torch.cat( + [x_start[:, :, : self.sv_points], self.q_sample(x_start=x_start[:, :, self.sv_points :], t=t_b)], + dim=-1, + ) new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=data_t, t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) + denoise_fn, + data_start=x_start, + data_t=data_t, + t=t_b, + clip_denoised=clip_denoised, + return_pred_xstart=True, + ) # MSE for progressive prediction loss - assert pred_xstart.shape == x_start[:, :, self.sv_points:].shape - new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points:]) ** 2).mean(dim=list(range(1, len(pred_xstart.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + assert pred_xstart.shape == x_start[:, :, self.sv_points :].shape + new_mse_b = ((pred_xstart - x_start[:, :, self.sv_points :]) ** 2).mean( + dim=list(range(1, len(pred_xstart.shape))) + ) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float() vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) - prior_bpd_b = self._prior_bpd(x_start[:,:,self.sv_points:]) + prior_bpd_b = self._prior_bpd(x_start[:, :, self.sv_points :]) total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + assert vals_bt_.shape == mse_bt_.shape == torch.Size( + [B, T] + ) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() @@ -385,39 +443,53 @@ class PVCNN2(PVCNN2Base): ((128, 128, 64), (64, 2, 32)), ] - def __init__(self, num_classes, sv_points, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): + def __init__( + self, + num_classes, + sv_points, + embed_dim, + use_att, + dropout, + extra_feature_channels=3, + width_multiplier=1, + voxel_resolution_multiplier=1, + ): super().__init__( - num_classes=num_classes, sv_points=sv_points, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + num_classes=num_classes, + sv_points=sv_points, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, ) class Model(nn.Module): - 
def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str): super(Model, self).__init__() self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type, args.svpoints) - self.model = PVCNN2(num_classes=args.nc, sv_points=args.svpoints, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) + self.model = PVCNN2( + num_classes=args.nc, + sv_points=args.svpoints, + embed_dim=args.embed_dim, + use_att=args.attention, + dropout=args.dropout, + extra_feature_channels=0, + ) def prior_kl(self, x0): return self.diffusion._prior_bpd(x0) def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt} def _denoise(self, data, t): - B, D,N= data.shape + B, D, N = data.shape assert data.dtype == torch.float assert t.shape == torch.Size([B]) and t.dtype == torch.int64 @@ -430,20 +502,22 @@ class Model(nn.Module): t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises) - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises) assert losses.shape == t.shape == torch.Size([B]) return losses - def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(partial_x, self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) - + def gen_samples(self, partial_x, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False): + return self.diffusion.p_sample_loop( + partial_x, + self._denoise, + shape=shape, + device=device, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running, + ) def train(self): self.model.train() @@ -456,20 +530,17 @@ class Model(nn.Module): def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': + if schedule_type == "linear": betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - + elif schedule_type == "warm0.1": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.1) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - + elif schedule_type == "warm0.2": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.2) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - + elif schedule_type == "warm0.5": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.5) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) @@ -479,18 +550,25 @@ def get_betas(schedule_type, 
b_start, b_end, time_num):


 def get_dataset(dataroot_pc, dataroot_sv, npoints, svpoints, category):
-    tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot_pc,
-        categories=[category], split='train',
+    tr_dataset = ShapeNet15kPointClouds(
+        root_dir=dataroot_pc,
+        categories=[category],
+        split="train",
         tr_sample_size=npoints,
         te_sample_size=npoints,
-        scale=1.,
+        scale=1.0,
         normalize_per_shape=False,
         normalize_std_per_axis=False,
-        random_subsample=True)
-    tr_dataset = ShapeNet_Multiview_Points(root_pc=dataroot_pc, root_views=dataroot_sv,
-        cache=os.path.join(dataroot_pc, '../cache'), split='train',
+        random_subsample=True,
+    )
+    tr_dataset = ShapeNet_Multiview_Points(
+        root_pc=dataroot_pc,
+        root_views=dataroot_sv,
+        cache=os.path.join(dataroot_pc, "../cache"),
+        split="train",
         categories=[category],
-        npoints=npoints, sv_samples=svpoints,
+        npoints=npoints,
+        sv_samples=svpoints,
         all_points_mean=tr_dataset.all_points_mean,
         all_points_std=tr_dataset.all_points_std,
     )
@@ -498,18 +576,13 @@ def get_dataset(dataroot_pc, dataroot_sv, npoints, svpoints, category):


 def get_dataloader(opt, train_dataset, test_dataset=None):
-
-    if opt.distribution_type == 'multi':
+    if opt.distribution_type == "multi":
         train_sampler = torch.utils.data.distributed.DistributedSampler(
-            train_dataset,
-            num_replicas=opt.world_size,
-            rank=opt.rank
+            train_dataset, num_replicas=opt.world_size, rank=opt.rank
         )
         if test_dataset is not None:
             test_sampler = torch.utils.data.distributed.DistributedSampler(
-                test_dataset,
-                num_replicas=opt.world_size,
-                rank=opt.rank
+                test_dataset, num_replicas=opt.world_size, rank=opt.rank
             )
         else:
             test_sampler = None
@@ -517,12 +590,24 @@ def get_dataloader(opt, train_dataset, test_dataset=None):
         train_sampler = None
         test_sampler = None

-    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
-                                                   shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=opt.bs,
+        sampler=train_sampler,
+        shuffle=train_sampler is None,
+        num_workers=int(opt.workers),
+        drop_last=True,
+    )

     if test_dataset is not None:
-        test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler,
-                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
+        test_dataloader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_size=opt.bs,
+            sampler=test_sampler,
+            shuffle=False,
+            num_workers=int(opt.workers),
+            drop_last=False,
+        )
     else:
         test_dataloader = None

@@ -530,58 +615,57 @@ def get_dataloader(opt, train_dataset, test_dataset=None):


 def train(gpu, opt, output_dir, noises_init):
-
     set_seed(opt)
     logger = setup_logging(output_dir)
-    if opt.distribution_type == 'multi':
-        should_diag = gpu==0
+    if opt.distribution_type == "multi":
+        should_diag = gpu == 0
     else:
         should_diag = True
     if should_diag:
-        outf_syn, = setup_output_subdirs(output_dir, 'syn')
+        (outf_syn,) = setup_output_subdirs(output_dir, "syn")

-    if opt.distribution_type == 'multi':
+    if opt.distribution_type == "multi":
        if opt.dist_url == "env://" and opt.rank == -1:
            opt.rank = int(os.environ["RANK"])

-        base_rank =  opt.rank * opt.ngpus_per_node
+        base_rank = opt.rank * opt.ngpus_per_node
         opt.rank = base_rank + gpu
-        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
-                                world_size=opt.world_size, rank=opt.rank)
+        dist.init_process_group(
+            backend=opt.dist_backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank
+        )

         opt.bs = 
int(opt.bs / opt.ngpus_per_node) opt.workers = 0 - opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) + opt.saveIter = int(opt.saveIter / opt.ngpus_per_node) opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) - - ''' data ''' - train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category) + """ data """ + train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints, opt.category) dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) - - ''' + """ create networks - ''' + """ betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - if opt.distribution_type == 'multi': # Multiple processes, single GPU per process + if opt.distribution_type == "multi": # Multiple processes, single GPU per process + def _transform_(m): - return nn.parallel.DistributedDataParallel( - m, device_ids=[gpu], output_device=gpu) + return nn.parallel.DistributedDataParallel(m, device_ids=[gpu], output_device=gpu) torch.cuda.set_device(gpu) model.cuda(gpu) model.multi_gpu_wrapper(_transform_) + elif opt.distribution_type == "single": - elif opt.distribution_type == 'single': def _transform_(m): return nn.parallel.DataParallel(m) + model = model.cuda() model.multi_gpu_wrapper(_transform_) @@ -589,49 +673,47 @@ def train(gpu, opt, output_dir, noises_init): torch.cuda.set_device(gpu) model = model.cuda(gpu) else: - raise ValueError('distribution_type = multi | single | None') + raise ValueError("distribution_type = multi | single | None") if should_diag: logger.info(opt) - optimizer= optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999)) + optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999)) lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma) - if opt.model != '': + if opt.model != "": ckpt = torch.load(opt.model) - model.load_state_dict(ckpt['model_state']) - optimizer.load_state_dict(ckpt['optimizer_state']) + model.load_state_dict(ckpt["model_state"]) + optimizer.load_state_dict(ckpt["optimizer_state"]) - if opt.model != '': - start_epoch = torch.load(opt.model)['epoch'] + 1 + if opt.model != "": + start_epoch = torch.load(opt.model)["epoch"] + 1 else: start_epoch = 0 - for epoch in range(start_epoch, opt.niter): - - if opt.distribution_type == 'multi': + if opt.distribution_type == "multi": train_sampler.set_epoch(epoch) lr_scheduler.step(epoch) for i, data in enumerate(dataloader): - randind = np.random.choice(20) #20 views - x = data['train_points'].transpose(1,2) - sv_x = data['sv_points'][:,randind].transpose(1,2) + randind = np.random.choice(20) # 20 views + x = data["train_points"].transpose(1, 2) + sv_x = data["sv_points"][:, randind].transpose(1, 2) - sv_x[:,:,opt.svpoints:] = x[:,:,opt.svpoints:] - noises_batch = noises_init[data['idx']].transpose(1,2) + sv_x[:, :, opt.svpoints :] = x[:, :, opt.svpoints :] + noises_batch = noises_init[data["idx"]].transpose(1, 2) - ''' + """ train diffusion - ''' + """ - if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): + if opt.distribution_type == "multi" or (opt.distribution_type is None and gpu is not None): sv_x = sv_x.cuda(gpu) noises_batch = noises_batch.cuda(gpu) - elif opt.distribution_type == 'single': + elif opt.distribution_type == "single": sv_x = sv_x.cuda() 
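
# --- illustrative aside (minimal sketch; toy sizes are assumptions): the layout
# convention used by this loop and by p_sample/p_sample_loop above. Point clouds
# are (B, C, N); the first `svpoints` columns hold the fixed partial view that
# conditions the model, and only the remaining columns are noised and denoised.
import torch

B, C, N, svpoints = 2, 3, 2048, 200
sv_x = torch.randn(B, C, N)
partial = sv_x[:, :, :svpoints]  # conditioning points, copied through each sampling step
target = sv_x[:, :, svpoints:]   # the completion region the diffusion model generates
assert torch.equal(torch.cat([partial, target], dim=-1), sv_x)
# --- end aside
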
noises_batch = noises_batch.cuda() @@ -645,100 +727,107 @@ def train(gpu, opt, output_dir, noises_init): optimizer.step() - if i % opt.print_freq == 0 and should_diag: - - logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' - 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' - .format( - epoch, opt.niter, i, len(dataloader),loss.item(), - netpNorm, netgradNorm, - )) - + logger.info( + "[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, " + "netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ".format( + epoch, + opt.niter, + i, + len(dataloader), + loss.item(), + netpNorm, + netgradNorm, + ) + ) if (epoch + 1) % opt.diagIter == 0 and should_diag: - - logger.info('Diagnosis:') + logger.info("Diagnosis:") x_range = [x.min().item(), x.max().item()] kl_stats = model.all_kl(sv_x) - logger.info(' [{:>3d}/{:>3d}] ' - 'x_range: [{:>10.4f}, {:>10.4f}], ' - 'total_bpd_b: {:>10.4f}, ' - 'terms_bpd: {:>10.4f}, ' - 'prior_bpd_b: {:>10.4f} ' - 'mse_bt: {:>10.4f} ' - .format( - epoch, opt.niter, - *x_range, - kl_stats['total_bpd_b'].item(), - kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item() - )) - - + logger.info( + " [{:>3d}/{:>3d}] " + "x_range: [{:>10.4f}, {:>10.4f}], " + "total_bpd_b: {:>10.4f}, " + "terms_bpd: {:>10.4f}, " + "prior_bpd_b: {:>10.4f} " + "mse_bt: {:>10.4f} ".format( + epoch, + opt.niter, + *x_range, + kl_stats["total_bpd_b"].item(), + kl_stats["terms_bpd"].item(), + kl_stats["prior_bpd_b"].item(), + kl_stats["mse_bt"].item(), + ) + ) if (epoch + 1) % opt.vizIter == 0 and should_diag: - logger.info('Generation: eval') + logger.info("Generation: eval") model.eval() m, s = train_dataset.all_points_mean.reshape(1, -1), train_dataset.all_points_std.reshape(1, -1) with torch.no_grad(): - - x_gen_eval = model.gen_samples(sv_x[:,:,:opt.svpoints], sv_x[:,:,opt.svpoints:].shape, sv_x.device, clip_denoised=False).detach().cpu() - + x_gen_eval = ( + model.gen_samples( + sv_x[:, :, : opt.svpoints], sv_x[:, :, opt.svpoints :].shape, sv_x.device, clip_denoised=False + ) + .detach() + .cpu() + ) gen_stats = [x_gen_eval.mean(), x_gen_eval.std()] gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()] - logger.info(' [{:>3d}/{:>3d}] ' - 'eval_gen_range: [{:>10.4f}, {:>10.4f}] ' - 'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ' - .format( - epoch, opt.niter, - *gen_eval_range, *gen_stats, - )) + logger.info( + " [{:>3d}/{:>3d}] " + "eval_gen_range: [{:>10.4f}, {:>10.4f}] " + "eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}] ".format( + epoch, + opt.niter, + *gen_eval_range, + *gen_stats, + ) + ) - export_to_pc_batch('%s/epoch_%03d_samples_eval' % (outf_syn, epoch), - (x_gen_eval.transpose(1, 2)*s+m).numpy()*3) + export_to_pc_batch( + "%s/epoch_%03d_samples_eval" % (outf_syn, epoch), (x_gen_eval.transpose(1, 2) * s + m).numpy() * 3 + ) - export_to_pc_batch('%s/epoch_%03d_ground_truth' % (outf_syn, epoch), - (sv_x.transpose(1, 2).detach().cpu()*s+m).numpy()*3) - - export_to_pc_batch('%s/epoch_%03d_partial' % (outf_syn, epoch), - (sv_x[:,:,:opt.svpoints].transpose(1, 2).detach().cpu()*s+m).numpy()*3) + export_to_pc_batch( + "%s/epoch_%03d_ground_truth" % (outf_syn, epoch), + (sv_x.transpose(1, 2).detach().cpu() * s + m).numpy() * 3, + ) + export_to_pc_batch( + "%s/epoch_%03d_partial" % (outf_syn, epoch), + (sv_x[:, :, : opt.svpoints].transpose(1, 2).detach().cpu() * s + m).numpy() * 3, + ) model.train() - - - - - - if (epoch + 1) % opt.saveIter == 0: - if should_diag: - - save_dict = { - 'epoch': epoch, - 'model_state': model.state_dict(), - 
'optimizer_state': optimizer.state_dict() + "epoch": epoch, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), } - torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch)) + torch.save(save_dict, "%s/epoch_%d.pth" % (output_dir, epoch)) - - if opt.distribution_type == 'multi': + if opt.distribution_type == "multi": dist.barrier() - map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu} + map_location = {"cuda:%d" % 0: "cuda:%d" % gpu} model.load_state_dict( - torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state']) + torch.load("%s/epoch_%d.pth" % (output_dir, epoch), map_location=map_location)["model_state"] + ) dist.destroy_process_group() + def main(): opt = parse_args() @@ -747,15 +836,15 @@ def main(): output_dir = get_output_dir(dir_id, exp_id) copy_source(__file__, output_dir) - ''' workaround ''' + """ workaround """ - train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints,opt.category) - noises_init = torch.randn(len(train_dataset), opt.npoints-opt.svpoints, opt.nc) + train_dataset = get_dataset(opt.dataroot_pc, opt.dataroot_sv, opt.npoints, opt.svpoints, opt.category) + noises_init = torch.randn(len(train_dataset), opt.npoints - opt.svpoints, opt.nc) if opt.dist_url == "env://" and opt.world_size == -1: opt.world_size = int(os.environ["WORLD_SIZE"]) - if opt.distribution_type == 'multi': + if opt.distribution_type == "multi": opt.ngpus_per_node = torch.cuda.device_count() opt.world_size = opt.ngpus_per_node * opt.world_size mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init)) @@ -763,73 +852,71 @@ def main(): train(opt.gpu, opt, output_dir, noises_init) - def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--dataroot_pc', default='ShapeNetCore.v2.PC15k/') - parser.add_argument('--dataroot_sv', default='GenReData/') - parser.add_argument('--category', default='chair') + parser.add_argument("--dataroot_pc", default="ShapeNetCore.v2.PC15k/") + parser.add_argument("--dataroot_sv", default="GenReData/") + parser.add_argument("--category", default="chair") - parser.add_argument('--bs', type=int, default=48, help='input batch size') - parser.add_argument('--workers', type=int, default=16, help='workers') - parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for') + parser.add_argument("--bs", type=int, default=48, help="input batch size") + parser.add_argument("--workers", type=int, default=16, help="workers") + parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for") - parser.add_argument('--nc', default=3) - parser.add_argument('--npoints', default=2048) - parser.add_argument('--svpoints', default=200) - '''model''' - parser.add_argument('--beta_start', default=0.0001) - parser.add_argument('--beta_end', default=0.02) - parser.add_argument('--schedule_type', default='linear') - parser.add_argument('--time_num', default=1000) + parser.add_argument("--nc", default=3) + parser.add_argument("--npoints", default=2048) + parser.add_argument("--svpoints", default=200) + """model""" + parser.add_argument("--beta_start", default=0.0001) + parser.add_argument("--beta_end", default=0.02) + parser.add_argument("--schedule_type", default="linear") + parser.add_argument("--time_num", default=1000) - #params - parser.add_argument('--attention', default=True) - parser.add_argument('--dropout', default=0.1) - parser.add_argument('--embed_dim', type=int, default=64) - 
parser.add_argument('--loss_type', default='mse')
-    parser.add_argument('--model_mean_type', default='eps')
-    parser.add_argument('--model_var_type', default='fixedsmall')
+    # params
+    parser.add_argument("--attention", default=True)
+    parser.add_argument("--dropout", default=0.1)
+    parser.add_argument("--embed_dim", type=int, default=64)
+    parser.add_argument("--loss_type", default="mse")
+    parser.add_argument("--model_mean_type", default="eps")
+    parser.add_argument("--model_var_type", default="fixedsmall")

-    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate for E, default=0.0002')
-    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
-    parser.add_argument('--decay', type=float, default=0, help='weight decay for EBM')
-    parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM')
-    parser.add_argument('--lr_gamma', type=float, default=0.998, help='lr decay for EBM')
+    parser.add_argument("--lr", type=float, default=2e-4, help="learning rate for E, default=0.0002")
+    parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5")
+    parser.add_argument("--decay", type=float, default=0, help="weight decay for EBM")
+    parser.add_argument("--grad_clip", type=float, default=None, help="max gradient norm for clipping (None disables)")
+    parser.add_argument("--lr_gamma", type=float, default=0.998, help="lr decay for EBM")

-    parser.add_argument('--model', default='', help="path to model (to continue training)")
+    parser.add_argument("--model", default="", help="path to model (to continue training)")

+    """distributed"""
+    parser.add_argument("--world_size", default=1, type=int, help="Number of distributed nodes.")
+    parser.add_argument(
+        "--dist_url", default="tcp://127.0.0.1:9991", type=str, help="url used to set up distributed training"
+    )
+    parser.add_argument("--dist_backend", default="nccl", type=str, help="distributed backend")
+    parser.add_argument(
+        "--distribution_type",
+        default="single",
+        choices=["multi", "single", None],
+        help="Use multi-processing distributed training to launch "
+        "N processes per node, which has N GPUs. This is the "
+        "fastest way to use PyTorch for either single node or "
+        "multi node data parallel training",
+    )
+    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+    parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")

-    '''distributed'''
-    parser.add_argument('--world_size', default=1, type=int,
-                        help='Number of distributed nodes.')
-    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
-                        help='url used to set up distributed training')
-    parser.add_argument('--dist_backend', default='nccl', type=str,
-                        help='distributed backend')
-    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
-                        help='Use multi-processing distributed training to launch '
-                             'N processes per node, which has N GPUs. This is the '
-                             'fastest way to use PyTorch for either single node or '
-                             'multi node data parallel training')
-    parser.add_argument('--rank', default=0, type=int,
-                        help='node rank for distributed training')
-    parser.add_argument('--gpu', default=None, type=int,
-                        help='GPU id to use. 
None means using all available GPUs.') - - '''eval''' - parser.add_argument('--saveIter', default=100, help='unit: epoch') - parser.add_argument('--diagIter', default=50, help='unit: epoch') - parser.add_argument('--vizIter', default=50, help='unit: epoch') - parser.add_argument('--print_freq', default=50, help='unit: iter') - - parser.add_argument('--manualSeed', default=42, type=int, help='random seed') + """eval""" + parser.add_argument("--saveIter", default=100, help="unit: epoch") + parser.add_argument("--diagIter", default=50, help="unit: epoch") + parser.add_argument("--vizIter", default=50, help="unit: epoch") + parser.add_argument("--print_freq", default=50, help="unit: iter") + parser.add_argument("--manualSeed", default=42, type=int, help="random seed") opt = parser.parse_args() return opt -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/train_generation.py b/train_generation.py index 83d39ad..e248334 100644 --- a/train_generation.py +++ b/train_generation.py @@ -1,20 +1,23 @@ +import argparse + +import numpy as np +import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim import torch.utils.data - -import argparse from torch.distributions import Normal +from datasets.shapenet_data_pc import ShapeNet15kPointClouds +from model.pvcnn_generation import PVCNN2Base from utils.file_utils import * from utils.visualize import * -from model.pvcnn_generation import PVCNN2Base -import torch.distributed as dist -from datasets.shapenet_data_pc import ShapeNet15kPointClouds -''' +""" some utils -''' +""" + + def rotation_matrix(axis, theta): """ Return the rotation matrix associated with counterclockwise rotation about @@ -26,29 +29,36 @@ def rotation_matrix(axis, theta): b, c, d = -axis * np.sin(theta / 2.0) aa, bb, cc, dd = a * a, b * b, c * c, d * d bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + return np.array( + [ + [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc], + ] + ) + def rotate(vertices, faces): - ''' + """ vertices: [numpoints, 3] - ''' + """ M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() K = rotation_matrix([0, 0, 1], np.pi).transpose() - v, f = vertices[:,[1,2,0]].dot(M).dot(N).dot(K), faces[:,[1,2,0]] + v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]] return v, f + def norm(v, f): - v = (v - v.min())/(v.max() - v.min()) - 0.5 + v = (v - v.min()) / (v.max() - v.min()) - 0.5 return v, f + def getGradNorm(net): - pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) - gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + pNorm = torch.sqrt(sum(torch.sum(p**2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad**2) for p in net.parameters())) return pNorm, gradNorm @@ -57,21 +67,24 @@ def weights_init(m): xavier initialization """ classname = m.__class__.__name__ - if classname.find('Conv') != -1 and m.weight is not None: + if classname.find("Conv") != -1 and m.weight is not None: torch.nn.init.xavier_normal_(m.weight) - elif classname.find('BatchNorm') != -1: + elif classname.find("BatchNorm") != -1: m.weight.data.normal_() m.bias.data.fill_(0) -''' 
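
# --- illustrative aside (minimal sketch; assumes the `rotation_matrix` defined
# above is in scope): a quick sanity check of the utility, which should produce a
# proper rotation (orthonormal, det = 1) that turns x into y for a 90-degree
# counterclockwise rotation about z.
import numpy as np

R = rotation_matrix([0, 0, 1], np.pi / 2)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-8)
assert np.isclose(np.linalg.det(R), 1.0)
assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0], atol=1e-8)
# --- end aside
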
+ +""" models -''' +""" + + def normal_kl(mean1, logvar1, mean2, logvar2): """ KL divergence between normal distributions parameterized by mean and log-variance. """ - return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) - + (mean1 - mean2)**2 * torch.exp(-logvar2)) + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + (mean1 - mean2) ** 2 * torch.exp(-logvar2)) + def discretized_gaussian_log_likelihood(x, *, means, log_scales): # Assumes data is integers [0, 1] @@ -82,36 +95,40 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales): inv_stdv = torch.exp(-log_scales) plus_in = inv_stdv * (centered_x + 0.5) cdf_plus = px0.cdf(plus_in) - min_in = inv_stdv * (centered_x - .5) + min_in = inv_stdv * (centered_x - 0.5) cdf_min = px0.cdf(min_in) - log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus)*1e-12)) - log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min)*1e-12)) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1.0 - cdf_min, torch.ones_like(cdf_min) * 1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = torch.where( - x < 0.001, log_cdf_plus, - torch.where(x > 0.999, log_one_minus_cdf_min, - torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta)*1e-12)))) + x < 0.001, + log_cdf_plus, + torch.where( + x > 0.999, log_one_minus_cdf_min, torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12)) + ), + ) assert log_probs.shape == x.shape return log_probs + class GaussianDiffusion: - def __init__(self,betas, loss_type, model_mean_type, model_var_type): + def __init__(self, betas, loss_type, model_mean_type, model_var_type): self.loss_type = loss_type self.model_mean_type = model_mean_type self.model_var_type = model_var_type assert isinstance(betas, np.ndarray) self.np_betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy assert (betas > 0).all() and (betas <= 1).all() - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) # initialize twice the actual length so we can keep running for eval # betas = np.concatenate([betas, np.full_like(betas[:int(0.2*len(betas))], betas[-1])]) - alphas = 1. - betas + alphas = 1.0 - betas alphas_cumprod = torch.from_numpy(np.cumprod(alphas, axis=0)).float() - alphas_cumprod_prev = torch.from_numpy(np.append(1., alphas_cumprod[:-1])).float() + alphas_cumprod_prev = torch.from_numpy(np.append(1.0, alphas_cumprod[:-1])).float() self.betas = torch.from_numpy(betas).float() self.alphas_cumprod = alphas_cumprod.float() @@ -119,21 +136,23 @@ class GaussianDiffusion: # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).float() - self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).float() - self.log_one_minus_alphas_cumprod = torch.log(1. - alphas_cumprod).float() - self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float() - self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. 
/ alphas_cumprod - 1).float() + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).float() + self.log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod).float() + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod).float() + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1).float() betas = torch.from_numpy(betas).float() alphas = torch.from_numpy(alphas).float() # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.posterior_variance = posterior_variance # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance))) - self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) - self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)) + ) + self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) @staticmethod def _extract(a, t, x_shape): @@ -141,17 +160,15 @@ class GaussianDiffusion: Extract some coefficients at specified timesteps, then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. """ - bs, = t.shape + (bs,) = t.shape assert x_shape[0] == bs out = torch.gather(a, 0, t) assert out.shape == torch.Size([bs]) return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1])) - - def q_mean_variance(self, x_start, t): mean = self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start - variance = self._extract(1. 
- self.alphas_cumprod.to(x_start.device), t, x_start.shape) + variance = self._extract(1.0 - self.alphas_cumprod.to(x_start.device), t, x_start.shape) log_variance = self._extract(self.log_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) return mean, variance, log_variance @@ -163,56 +180,62 @@ class GaussianDiffusion: noise = torch.randn(x_start.shape, device=x_start.device) assert noise.shape == x_start.shape return ( - self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + - self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise + self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise ) - def q_posterior_mean_variance(self, x_start, x_t, t): """ Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( - self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + - self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t + self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start + + self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t ) posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape) - posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape) - assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == - x_start.shape[0]) + posterior_log_variance_clipped = self._extract( + self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, denoise_fn, data, t, clip_denoised: bool, return_pred_xstart: bool): - model_output = denoise_fn(data, t) - - if self.model_var_type in ['fixedsmall', 'fixedlarge']: + if self.model_var_type in ["fixedsmall", "fixedlarge"]: # below: only log_variance is used in the KL computations model_variance, model_log_variance = { # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood - 'fixedlarge': (self.betas.to(data.device), - torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)), - 'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)), + "fixedlarge": ( + self.betas.to(data.device), + torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device), + ), + "fixedsmall": ( + self.posterior_variance.to(data.device), + self.posterior_log_variance_clipped.to(data.device), + ), }[self.model_var_type] model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data) model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data) else: raise NotImplementedError(self.model_var_type) - if self.model_mean_type == 'eps': + if self.model_mean_type == "eps": x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output) if clip_denoised: - x_recon = torch.clamp(x_recon, -.5, .5) + x_recon = torch.clamp(x_recon, -0.5, 0.5) model_mean, _, _ = 
self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t) else: raise NotImplementedError(self.loss_type) - assert model_mean.shape == x_recon.shape == data.shape assert model_variance.shape == model_log_variance.shape == data.shape if return_pred_xstart: @@ -223,18 +246,19 @@ class GaussianDiffusion: def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t - - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps + self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t + - self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps ) - ''' samples ''' + """ samples """ def p_sample(self, denoise_fn, data, t, noise_fn, clip_denoised=False, return_pred_xstart=False): """ Sample from the model """ - model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, clip_denoised=clip_denoised, - return_pred_xstart=True) + model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( + denoise_fn, data=data, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device) assert noise.shape == data.shape # no noise when t == 0 @@ -244,9 +268,7 @@ class GaussianDiffusion: assert sample.shape == pred_xstart.shape return (sample, pred_xstart) if return_pred_xstart else sample - - def p_sample_loop(self, denoise_fn, shape, device, - noise_fn=torch.randn, clip_denoised=True, keep_running=False): + def p_sample_loop(self, denoise_fn, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False): """ Generate samples keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps @@ -257,14 +279,21 @@ class GaussianDiffusion: img_t = noise_fn(size=shape, dtype=torch.float, device=device) for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))): t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, return_pred_xstart=False) + img_t = self.p_sample( + denoise_fn=denoise_fn, + data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ) assert img_t.shape == shape return img_t - def p_sample_loop_trajectory(self, denoise_fn, shape, device, freq, - noise_fn=torch.randn,clip_denoised=True, keep_running=False): + def p_sample_loop_trajectory( + self, denoise_fn, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False + ): """ Generate samples, returning intermediate images Useful for visualizing how denoised images evolve over time @@ -274,30 +303,35 @@ class GaussianDiffusion: """ assert isinstance(shape, (tuple, list)) - total_steps = self.num_timesteps if not keep_running else len(self.betas) + total_steps = self.num_timesteps if not keep_running else len(self.betas) img_t = noise_fn(size=shape, dtype=torch.float, device=device) imgs = [img_t] - for t in reversed(range(0,total_steps)): - + for t in reversed(range(0, total_steps)): t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t) - img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_, noise_fn=noise_fn, - clip_denoised=clip_denoised, - return_pred_xstart=False) - if t % freq == 0 or t == total_steps-1: + img_t = self.p_sample( + denoise_fn=denoise_fn, + 
data=img_t, + t=t_, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + return_pred_xstart=False, + ) + if t % freq == 0 or t == total_steps - 1: imgs.append(img_t) assert imgs[-1].shape == shape return imgs - '''losses''' + """losses""" def _vb_terms_bpd(self, denoise_fn, data_start, data_t, t, clip_denoised: bool, return_pred_xstart: bool): true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=data_start, x_t=data_t, t=t) model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance( - denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True) + denoise_fn, data=data_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True + ) kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance) - kl = kl.mean(dim=list(range(1, len(data_start.shape)))) / np.log(2.) + kl = kl.mean(dim=list(range(1, len(data_start.shape)))) / np.log(2.0) return (kl, pred_xstart) if return_pred_xstart else kl @@ -314,63 +348,75 @@ class GaussianDiffusion: data_t = self.q_sample(x_start=data_start, t=t, noise=noise) - if self.loss_type == 'mse': + if self.loss_type == "mse": # predict the noise instead of x_start. seems to be weighted naturally like SNR eps_recon = denoise_fn(data_t, t) assert data_t.shape == data_start.shape assert eps_recon.shape == torch.Size([B, D, N]) assert eps_recon.shape == data_start.shape - losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape)))) - elif self.loss_type == 'kl': + losses = ((noise - eps_recon) ** 2).mean(dim=list(range(1, len(data_start.shape)))) + elif self.loss_type == "kl": losses = self._vb_terms_bpd( - denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, clip_denoised=False, - return_pred_xstart=False) + denoise_fn=denoise_fn, + data_start=data_start, + data_t=data_t, + t=t, + clip_denoised=False, + return_pred_xstart=False, + ) else: raise NotImplementedError(self.loss_type) assert losses.shape == torch.Size([B]) return losses - '''debug''' + """debug""" def _prior_bpd(self, x_start): - with torch.no_grad(): B, T = x_start.shape[0], self.num_timesteps - t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T-1) + t_ = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(T - 1) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t=t_) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, - mean2=torch.tensor([0.]).to(qt_mean), logvar2=torch.tensor([0.]).to(qt_log_variance)) + kl_prior = normal_kl( + mean1=qt_mean, + logvar1=qt_log_variance, + mean2=torch.tensor([0.0]).to(qt_mean), + logvar2=torch.tensor([0.0]).to(qt_log_variance), + ) assert kl_prior.shape == x_start.shape - return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.) 
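
# --- illustrative aside (minimal sketch with toy sizes): the masked-scatter
# pattern calc_bpd_loop uses below to write one per-timestep column of a [B, T]
# buffer, equivalent to `buf[:, t] = new_vals` but written with broadcasting.
import torch

B, T, t = 3, 5, 2
buf = torch.zeros(B, T)
new_vals = torch.tensor([1.0, 2.0, 3.0])
t_b = torch.full((B,), t, dtype=torch.int64)
mask = t_b[:, None] == torch.arange(T)[None, :]  # bool [B, T], one True column per row
buf = buf * (~mask) + new_vals[:, None] * mask
assert torch.equal(buf[:, t], new_vals)
# --- end aside
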
+ return kl_prior.mean(dim=list(range(1, len(kl_prior.shape)))) / np.log(2.0) def calc_bpd_loop(self, denoise_fn, x_start, clip_denoised=True): - with torch.no_grad(): B, T = x_start.shape[0], self.num_timesteps - vals_bt_, mse_bt_= torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) + vals_bt_, mse_bt_ = torch.zeros([B, T], device=x_start.device), torch.zeros([B, T], device=x_start.device) for t in reversed(range(T)): - t_b = torch.empty(B, dtype=torch.int64, device=x_start.device).fill_(t) # Calculate VLB term at the current timestep new_vals_b, pred_xstart = self._vb_terms_bpd( - denoise_fn, data_start=x_start, data_t=self.q_sample(x_start=x_start, t=t_b), t=t_b, - clip_denoised=clip_denoised, return_pred_xstart=True) + denoise_fn, + data_start=x_start, + data_t=self.q_sample(x_start=x_start, t=t_b), + t=t_b, + clip_denoised=clip_denoised, + return_pred_xstart=True, + ) # MSE for progressive prediction loss assert pred_xstart.shape == x_start.shape - new_mse_b = ((pred_xstart-x_start)**2).mean(dim=list(range(1, len(x_start.shape)))) - assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) + new_mse_b = ((pred_xstart - x_start) ** 2).mean(dim=list(range(1, len(x_start.shape)))) + assert new_vals_b.shape == new_mse_b.shape == torch.Size([B]) # Insert the calculated term into the tensor of all terms - mask_bt = t_b[:, None]==torch.arange(T, device=t_b.device)[None, :].float() + mask_bt = t_b[:, None] == torch.arange(T, device=t_b.device)[None, :].float() vals_bt_ = vals_bt_ * (~mask_bt) + new_vals_b[:, None] * mask_bt mse_bt_ = mse_bt_ * (~mask_bt) + new_mse_b[:, None] * mask_bt assert mask_bt.shape == vals_bt_.shape == vals_bt_.shape == torch.Size([B, T]) prior_bpd_b = self._prior_bpd(x_start) total_bpd_b = vals_bt_.sum(dim=1) + prior_bpd_b - assert vals_bt_.shape == mse_bt_.shape == torch.Size([B, T]) and \ - total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) + assert vals_bt_.shape == mse_bt_.shape == torch.Size( + [B, T] + ) and total_bpd_b.shape == prior_bpd_b.shape == torch.Size([B]) return total_bpd_b.mean(), vals_bt_.mean(), prior_bpd_b.mean(), mse_bt_.mean() @@ -388,39 +434,50 @@ class PVCNN2(PVCNN2Base): ((128, 128, 64), (64, 2, 32)), ] - def __init__(self, num_classes, embed_dim, use_att,dropout, extra_feature_channels=3, width_multiplier=1, - voxel_resolution_multiplier=1): + def __init__( + self, + num_classes, + embed_dim, + use_att, + dropout, + extra_feature_channels=3, + width_multiplier=1, + voxel_resolution_multiplier=1, + ): super().__init__( - num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, - dropout=dropout, extra_feature_channels=extra_feature_channels, - width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + num_classes=num_classes, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, ) class Model(nn.Module): - def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type:str): + def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str): super(Model, self).__init__() self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type) - self.model = PVCNN2(num_classes=args.nc, embed_dim=args.embed_dim, use_att=args.attention, - dropout=args.dropout, extra_feature_channels=0) + self.model = PVCNN2( + num_classes=args.nc, + 
embed_dim=args.embed_dim, + use_att=args.attention, + dropout=args.dropout, + extra_feature_channels=0, + ) def prior_kl(self, x0): return self.diffusion._prior_bpd(x0) def all_kl(self, x0, clip_denoised=True): - total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) - - return { - 'total_bpd_b': total_bpd_b, - 'terms_bpd': vals_bt, - 'prior_bpd_b': prior_bpd_b, - 'mse_bt':mse_bt - } + total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, clip_denoised) + return {"total_bpd_b": total_bpd_b, "terms_bpd": vals_bt, "prior_bpd_b": prior_bpd_b, "mse_bt": mse_bt} def _denoise(self, data, t): - B, D,N= data.shape + B, D, N = data.shape assert data.dtype == torch.float assert t.shape == torch.Size([B]) and t.dtype == torch.int64 @@ -434,25 +491,32 @@ class Model(nn.Module): t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device) if noises is not None: - noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises) + noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises) - losses = self.diffusion.p_losses( - denoise_fn=self._denoise, data_start=data, t=t, noise=noises) + losses = self.diffusion.p_losses(denoise_fn=self._denoise, data_start=data, t=t, noise=noises) assert losses.shape == t.shape == torch.Size([B]) return losses - def gen_samples(self, shape, device, noise_fn=torch.randn, - clip_denoised=True, - keep_running=False): - return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, noise_fn=noise_fn, - clip_denoised=clip_denoised, - keep_running=keep_running) + def gen_samples(self, shape, device, noise_fn=torch.randn, clip_denoised=True, keep_running=False): + return self.diffusion.p_sample_loop( + self._denoise, + shape=shape, + device=device, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + keep_running=keep_running, + ) - def gen_sample_traj(self, shape, device, freq, noise_fn=torch.randn, - clip_denoised=True,keep_running=False): - return self.diffusion.p_sample_loop_trajectory(self._denoise, shape=shape, device=device, noise_fn=noise_fn, freq=freq, - clip_denoised=clip_denoised, - keep_running=keep_running) + def gen_sample_traj(self, shape, device, freq, noise_fn=torch.randn, clip_denoised=True, keep_running=False): + return self.diffusion.p_sample_loop_trajectory( + self._denoise, + shape=shape, + device=device, + noise_fn=noise_fn, + freq=freq, + clip_denoised=clip_denoised, + keep_running=keep_running, + ) def train(self): self.model.train() @@ -465,20 +529,17 @@ class Model(nn.Module): def get_betas(schedule_type, b_start, b_end, time_num): - if schedule_type == 'linear': + if schedule_type == "linear": betas = np.linspace(b_start, b_end, time_num) - elif schedule_type == 'warm0.1': - + elif schedule_type == "warm0.1": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.1) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.2': - + elif schedule_type == "warm0.2": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.2) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) - elif schedule_type == 'warm0.5': - + elif schedule_type == "warm0.5": betas = b_end * np.ones(time_num, dtype=np.float64) warmup_time = int(time_num * 0.5) betas[:warmup_time] = np.linspace(b_start, b_end, warmup_time, dtype=np.float64) @@ -487,20 +548,25 @@ def get_betas(schedule_type, b_start, 
b_end, time_num):
     return betas


-def get_dataset(dataroot, npoints,category):
-    tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
-        categories=[category], split='train',
+def get_dataset(dataroot, npoints, category):
+    tr_dataset = ShapeNet15kPointClouds(
+        root_dir=dataroot,
+        categories=[category],
+        split="train",
         tr_sample_size=npoints,
         te_sample_size=npoints,
-        scale=1.,
+        scale=1.0,
         normalize_per_shape=False,
         normalize_std_per_axis=False,
-        random_subsample=True)
-    te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
-        categories=[category], split='val',
+        random_subsample=True,
+    )
+    te_dataset = ShapeNet15kPointClouds(
+        root_dir=dataroot,
+        categories=[category],
+        split="val",
         tr_sample_size=npoints,
         te_sample_size=npoints,
-        scale=1.,
+        scale=1.0,
         normalize_per_shape=False,
         normalize_std_per_axis=False,
         all_points_mean=tr_dataset.all_points_mean,
@@ -510,18 +576,13 @@ def get_dataset(dataroot, npoints,category):


 def get_dataloader(opt, train_dataset, test_dataset=None):
-
-    if opt.distribution_type == 'multi':
+    if opt.distribution_type == "multi":
         train_sampler = torch.utils.data.distributed.DistributedSampler(
-            train_dataset,
-            num_replicas=opt.world_size,
-            rank=opt.rank
+            train_dataset, num_replicas=opt.world_size, rank=opt.rank
         )
         if test_dataset is not None:
             test_sampler = torch.utils.data.distributed.DistributedSampler(
-                test_dataset,
-                num_replicas=opt.world_size,
-                rank=opt.rank
+                test_dataset, num_replicas=opt.world_size, rank=opt.rank
             )
         else:
             test_sampler = None
@@ -529,12 +590,24 @@ def get_dataloader(opt, train_dataset, test_dataset=None):
         train_sampler = None
         test_sampler = None

-    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=train_sampler,
-                                                   shuffle=train_sampler is None, num_workers=int(opt.workers), drop_last=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=opt.bs,
+        sampler=train_sampler,
+        shuffle=train_sampler is None,
+        num_workers=int(opt.workers),
+        drop_last=True,
+    )

     if test_dataset is not None:
-        test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.bs,sampler=test_sampler,
-                                                      shuffle=False, num_workers=int(opt.workers), drop_last=False)
+        test_dataloader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_size=opt.bs,
+            sampler=test_sampler,
+            shuffle=False,
+            num_workers=int(opt.workers),
+            drop_last=False,
+        )
     else:
         test_dataloader = None

@@ -542,58 +615,57 @@ def get_dataloader(opt, train_dataset, test_dataset=None):


 def train(gpu, opt, output_dir, noises_init):
-
     set_seed(opt)
     logger = setup_logging(output_dir)
-    if opt.distribution_type == 'multi':
-        should_diag = gpu==0
+    if opt.distribution_type == "multi":
+        should_diag = gpu == 0
     else:
         should_diag = True
     if should_diag:
-        outf_syn, = setup_output_subdirs(output_dir, 'syn')
+        (outf_syn,) = setup_output_subdirs(output_dir, "syn")

-    if opt.distribution_type == 'multi':
+    if opt.distribution_type == "multi":
        if opt.dist_url == "env://" and opt.rank == -1:
            opt.rank = int(os.environ["RANK"])

-        base_rank =  opt.rank * opt.ngpus_per_node
+        base_rank = opt.rank * opt.ngpus_per_node
         opt.rank = base_rank + gpu
-        dist.init_process_group(backend=opt.dist_backend, init_method=opt.dist_url,
-                                world_size=opt.world_size, rank=opt.rank)
+        dist.init_process_group(
+            backend=opt.dist_backend, init_method=opt.dist_url, world_size=opt.world_size, rank=opt.rank
+        )

         opt.bs = int(opt.bs / opt.ngpus_per_node)
         opt.workers = 0

-        opt.saveIter = int(opt.saveIter / opt.ngpus_per_node)
+        opt.saveIter = 
int(opt.saveIter / opt.ngpus_per_node) opt.diagIter = int(opt.diagIter / opt.ngpus_per_node) opt.vizIter = int(opt.vizIter / opt.ngpus_per_node) - - ''' data ''' + """ data """ train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category) dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) - - ''' + """ create networks - ''' + """ betas = get_betas(opt.schedule_type, opt.beta_start, opt.beta_end, opt.time_num) model = Model(opt, betas, opt.loss_type, opt.model_mean_type, opt.model_var_type) - if opt.distribution_type == 'multi': # Multiple processes, single GPU per process + if opt.distribution_type == "multi": # Multiple processes, single GPU per process + def _transform_(m): - return nn.parallel.DistributedDataParallel( - m, device_ids=[gpu], output_device=gpu) + return nn.parallel.DistributedDataParallel(m, device_ids=[gpu], output_device=gpu) torch.cuda.set_device(gpu) model.cuda(gpu) model.multi_gpu_wrapper(_transform_) + elif opt.distribution_type == "single": - elif opt.distribution_type == 'single': def _transform_(m): return nn.parallel.DataParallel(m) + model = model.cuda() model.multi_gpu_wrapper(_transform_) @@ -601,49 +673,46 @@ def train(gpu, opt, output_dir, noises_init): torch.cuda.set_device(gpu) model = model.cuda(gpu) else: - raise ValueError('distribution_type = multi | single | None') + raise ValueError("distribution_type = multi | single | None") if should_diag: logger.info(opt) - optimizer= optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999)) + optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999)) lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma) - if opt.model != '': + if opt.model != "": ckpt = torch.load(opt.model) - model.load_state_dict(ckpt['model_state']) - optimizer.load_state_dict(ckpt['optimizer_state']) + model.load_state_dict(ckpt["model_state"]) + optimizer.load_state_dict(ckpt["optimizer_state"]) - if opt.model != '': - start_epoch = torch.load(opt.model)['epoch'] + 1 + if opt.model != "": + start_epoch = torch.load(opt.model)["epoch"] + 1 else: start_epoch = 0 def new_x_chain(x, num_chain): return torch.randn(num_chain, *x.shape[1:], device=x.device) - - for epoch in range(start_epoch, opt.niter): - - if opt.distribution_type == 'multi': + if opt.distribution_type == "multi": train_sampler.set_epoch(epoch) lr_scheduler.step(epoch) for i, data in enumerate(dataloader): - x = data['train_points'].transpose(1,2) - noises_batch = noises_init[data['idx']].transpose(1,2) + x = data["train_points"].transpose(1, 2) + noises_batch = noises_init[data["idx"]].transpose(1, 2) - ''' + """ train diffusion - ''' + """ - if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None): + if opt.distribution_type == "multi" or (opt.distribution_type is None and gpu is not None): x = x.cuda(gpu) noises_batch = noises_batch.cuda(gpu) - elif opt.distribution_type == 'single': + elif opt.distribution_type == "single": x = x.cuda() noises_batch = noises_batch.cuda() @@ -657,44 +726,47 @@ def train(gpu, opt, output_dir, noises_init): optimizer.step() - if i % opt.print_freq == 0 and should_diag: - - logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, ' - 'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ' - .format( - epoch, opt.niter, i, len(dataloader),loss.item(), - netpNorm, netgradNorm, - )) - + logger.info( + "[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, " + "netpNorm: {:>10.2f}, netgradNorm: 
@@ -601,49 +673,46 @@ def train(gpu, opt, output_dir, noises_init):
         torch.cuda.set_device(gpu)
         model = model.cuda(gpu)
     else:
-        raise ValueError('distribution_type = multi | single | None')
+        raise ValueError("distribution_type = multi | single | None")

     if should_diag:
         logger.info(opt)

-    optimizer= optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999))
+    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.decay, betas=(opt.beta1, 0.999))

     lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, opt.lr_gamma)

-    if opt.model != '':
-        ckpt = torch.load(opt.model)
-        model.load_state_dict(ckpt['model_state'])
-        optimizer.load_state_dict(ckpt['optimizer_state'])
-
-    if opt.model != '':
-        start_epoch = torch.load(opt.model)['epoch'] + 1
+    if opt.model != "":
+        # Load the checkpoint once and reuse it; the old code read the file twice.
+        ckpt = torch.load(opt.model)
+        model.load_state_dict(ckpt["model_state"])
+        optimizer.load_state_dict(ckpt["optimizer_state"])
+        start_epoch = ckpt["epoch"] + 1
     else:
         start_epoch = 0

     def new_x_chain(x, num_chain):
         return torch.randn(num_chain, *x.shape[1:], device=x.device)

-
     for epoch in range(start_epoch, opt.niter):
-
-        if opt.distribution_type == 'multi':
+        if opt.distribution_type == "multi":
             train_sampler.set_epoch(epoch)

         lr_scheduler.step(epoch)

         for i, data in enumerate(dataloader):
-            x = data['train_points'].transpose(1,2)
-            noises_batch = noises_init[data['idx']].transpose(1,2)
+            x = data["train_points"].transpose(1, 2)
+            noises_batch = noises_init[data["idx"]].transpose(1, 2)

-            '''
+            """
             train diffusion
-            '''
+            """

-            if opt.distribution_type == 'multi' or (opt.distribution_type is None and gpu is not None):
+            if opt.distribution_type == "multi" or (opt.distribution_type is None and gpu is not None):
                 x = x.cuda(gpu)
                 noises_batch = noises_batch.cuda(gpu)
-            elif opt.distribution_type == 'single':
+            elif opt.distribution_type == "single":
                 x = x.cuda()
                 noises_batch = noises_batch.cuda()

@@ -657,44 +726,47 @@ def train(gpu, opt, output_dir, noises_init):

             optimizer.step()

-            if i % opt.print_freq == 0 and should_diag:
-
-                logger.info('[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, '
-                            'netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} '
-                            .format(
-                        epoch, opt.niter, i, len(dataloader),loss.item(),
-                        netpNorm, netgradNorm,
-                        ))
-
+            if i % opt.print_freq == 0 and should_diag:
+                logger.info(
+                    "[{:>3d}/{:>3d}][{:>3d}/{:>3d}] loss: {:>10.4f}, "
+                    "netpNorm: {:>10.2f}, netgradNorm: {:>10.2f} ".format(
+                        epoch,
+                        opt.niter,
+                        i,
+                        len(dataloader),
+                        loss.item(),
+                        netpNorm,
+                        netgradNorm,
+                    )
+                )

         if (epoch + 1) % opt.diagIter == 0 and should_diag:
-
-            logger.info('Diagnosis:')
+            logger.info("Diagnosis:")

             x_range = [x.min().item(), x.max().item()]
             kl_stats = model.all_kl(x)
-            logger.info('      [{:>3d}/{:>3d}]    '
-                        'x_range: [{:>10.4f}, {:>10.4f}],   '
-                        'total_bpd_b: {:>10.4f},    '
-                        'terms_bpd: {:>10.4f},  '
-                        'prior_bpd_b: {:>10.4f}    '
-                        'mse_bt: {:>10.4f}  '
-                .format(
-                epoch, opt.niter,
-                *x_range,
-                kl_stats['total_bpd_b'].item(),
-                kl_stats['terms_bpd'].item(), kl_stats['prior_bpd_b'].item(), kl_stats['mse_bt'].item()
-            ))
-
-
+            logger.info(
+                "      [{:>3d}/{:>3d}]    "
+                "x_range: [{:>10.4f}, {:>10.4f}],   "
+                "total_bpd_b: {:>10.4f},    "
+                "terms_bpd: {:>10.4f},  "
+                "prior_bpd_b: {:>10.4f}    "
+                "mse_bt: {:>10.4f}  ".format(
+                    epoch,
+                    opt.niter,
+                    *x_range,
+                    kl_stats["total_bpd_b"].item(),
+                    kl_stats["terms_bpd"].item(),
+                    kl_stats["prior_bpd_b"].item(),
+                    kl_stats["mse_bt"].item(),
+                )
+            )

         if (epoch + 1) % opt.vizIter == 0 and should_diag:
-            logger.info('Generation: eval')
+            logger.info("Generation: eval")

             model.eval()
             with torch.no_grad():
-
                 x_gen_eval = model.gen_samples(new_x_chain(x, 25).shape, x.device, clip_denoised=False)
                 x_gen_list = model.gen_sample_traj(new_x_chain(x, 1).shape, x.device, freq=40, clip_denoised=False)
                 x_gen_all = torch.cat(x_gen_list, dim=0)
@@ -702,79 +774,70 @@ def train(gpu, opt, output_dir, noises_init):
                 gen_stats = [x_gen_eval.mean(), x_gen_eval.std()]
                 gen_eval_range = [x_gen_eval.min().item(), x_gen_eval.max().item()]

-                logger.info('      [{:>3d}/{:>3d}]  '
-                            'eval_gen_range: [{:>10.4f}, {:>10.4f}]     '
-                            'eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}]      '
-                    .format(
-                    epoch, opt.niter,
-                    *gen_eval_range, *gen_stats,
-                ))
+                logger.info(
+                    "      [{:>3d}/{:>3d}]  "
+                    "eval_gen_range: [{:>10.4f}, {:>10.4f}]     "
+                    "eval_gen_stats: [mean={:>10.4f}, std={:>10.4f}]      ".format(
+                        epoch,
+                        opt.niter,
+                        *gen_eval_range,
+                        *gen_stats,
+                    )
+                )

-            visualize_pointcloud_batch('%s/epoch_%03d_samples_eval.png' % (outf_syn, epoch),
-                                       x_gen_eval.transpose(1, 2), None, None,
-                                       None)
+            visualize_pointcloud_batch(
+                "%s/epoch_%03d_samples_eval.png" % (outf_syn, epoch), x_gen_eval.transpose(1, 2), None, None, None
+            )

-            visualize_pointcloud_batch('%s/epoch_%03d_samples_eval_all.png' % (outf_syn, epoch),
-                                       x_gen_all.transpose(1, 2), None,
-                                       None,
-                                       None)
+            visualize_pointcloud_batch(
+                "%s/epoch_%03d_samples_eval_all.png" % (outf_syn, epoch), x_gen_all.transpose(1, 2), None, None, None
+            )

-            visualize_pointcloud_batch('%s/epoch_%03d_x.png' % (outf_syn, epoch), x.transpose(1, 2), None,
-                                       None,
-                                       None)
+            visualize_pointcloud_batch("%s/epoch_%03d_x.png" % (outf_syn, epoch), x.transpose(1, 2), None, None, None)

-            logger.info('Generation: train')
+            logger.info("Generation: train")
             model.train()

-
-
-
-
-
-
         if (epoch + 1) % opt.saveIter == 0:
-
             if should_diag:
-
                 save_dict = {
-                    'epoch': epoch,
-                    'model_state': model.state_dict(),
-                    'optimizer_state': optimizer.state_dict()
+                    "epoch": epoch,
+                    "model_state": model.state_dict(),
+                    "optimizer_state": optimizer.state_dict(),
                 }

-                torch.save(save_dict, '%s/epoch_%d.pth' % (output_dir, epoch))
+                torch.save(save_dict, "%s/epoch_%d.pth" % (output_dir, epoch))

-
-            if opt.distribution_type == 'multi':
+            if opt.distribution_type == "multi":
                 dist.barrier()
-                map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
+                map_location = {"cuda:%d" % 0: "cuda:%d" % gpu}
                 model.load_state_dict(
-                    torch.load('%s/epoch_%d.pth' % (output_dir, epoch), map_location=map_location)['model_state'])
+                    torch.load("%s/epoch_%d.pth" % (output_dir, epoch), map_location=map_location)["model_state"]
+                )

-    dist.destroy_process_group()
+    if opt.distribution_type == "multi":
+        # fix: destroy_process_group() raises when no group was initialized,
+        # so guard it for the single-GPU and DataParallel paths
+        dist.destroy_process_group()

+
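The checkpoint dictionary written above round-trips through the `--model` flag. A resume sketch, where the path is hypothetical and `model`/`optimizer` are assumed to be constructed exactly as in `train()`:

    ckpt = torch.load("output/train_generation/2024-01-01-00-00-00/epoch_99.pth")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    start_epoch = ckpt["epoch"] + 1  # mirrors the resume branch earlier in train()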
 def main():
     opt = parse_args()

-    if opt.category == 'airplane':
+    if opt.category == "airplane":
         opt.beta_start = 1e-5
         opt.beta_end = 0.008
-        opt.schedule_type = 'warm0.1'
+        opt.schedule_type = "warm0.1"

     exp_id = os.path.splitext(os.path.basename(__file__))[0]
     dir_id = os.path.dirname(__file__)
     output_dir = get_output_dir(dir_id, exp_id)
     copy_source(__file__, output_dir)

-    ''' workaround '''
+    """ workaround """
     train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
     noises_init = torch.randn(len(train_dataset), opt.npoints, opt.nc)

     if opt.dist_url == "env://" and opt.world_size == -1:
         opt.world_size = int(os.environ["WORLD_SIZE"])

-    if opt.distribution_type == 'multi':
+    if opt.distribution_type == "multi":
         opt.ngpus_per_node = torch.cuda.device_count()
         opt.world_size = opt.ngpus_per_node * opt.world_size
         mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
@@ -782,71 +845,69 @@ def main():
         train(opt.gpu, opt, output_dir, noises_init)


-def parse_args():
-
+def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--dataroot', default='ShapeNetCore.v2.PC15k/')
-    parser.add_argument('--category', default='chair')
+    parser.add_argument("--dataroot", default="ShapeNetCore.v2.PC15k/")
+    parser.add_argument("--category", default="chair")

-    parser.add_argument('--bs', type=int, default=16, help='input batch size')
-    parser.add_argument('--workers', type=int, default=16, help='workers')
-    parser.add_argument('--niter', type=int, default=10000, help='number of epochs to train for')
+    parser.add_argument("--bs", type=int, default=16, help="input batch size")
+    parser.add_argument("--workers", type=int, default=16, help="number of data-loading workers")
+    parser.add_argument("--niter", type=int, default=10000, help="number of epochs to train for")

-    parser.add_argument('--nc', default=3)
-    parser.add_argument('--npoints', default=2048)
-    '''model'''
-    parser.add_argument('--beta_start', default=0.0001)
-    parser.add_argument('--beta_end', default=0.02)
-    parser.add_argument('--schedule_type', default='linear')
-    parser.add_argument('--time_num', default=1000)
+    parser.add_argument("--nc", default=3)
+    parser.add_argument("--npoints", default=2048)
+    """model"""
+    parser.add_argument("--beta_start", default=0.0001)
+    parser.add_argument("--beta_end", default=0.02)
+    parser.add_argument("--schedule_type", default="linear")
+    parser.add_argument("--time_num", default=1000)

-    #params
-    parser.add_argument('--attention', default=True)
-    parser.add_argument('--dropout', default=0.1)
-    parser.add_argument('--embed_dim', type=int, default=64)
-    parser.add_argument('--loss_type', default='mse')
-    parser.add_argument('--model_mean_type', default='eps')
-    parser.add_argument('--model_var_type', default='fixedsmall')
+    # params
+    parser.add_argument("--attention", default=True)
+    parser.add_argument("--dropout", default=0.1)
+    parser.add_argument("--embed_dim", type=int, default=64)
+    parser.add_argument("--loss_type", default="mse")
+    parser.add_argument("--model_mean_type", default="eps")
+    parser.add_argument("--model_var_type", default="fixedsmall")

-    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate for E, default=0.0002')
-    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
-    parser.add_argument('--decay', type=float, default=0, help='weight decay for EBM')
-    parser.add_argument('--grad_clip', type=float, default=None, help='weight decay for EBM')
-    parser.add_argument('--lr_gamma', type=float, default=0.998, help='lr decay for EBM')
+    parser.add_argument("--lr", type=float, default=2e-4, help="learning rate, default=0.0002")
+    parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5")
+    parser.add_argument("--decay", type=float, default=0, help="weight decay")
+    parser.add_argument("--grad_clip", type=float, default=None, help="max norm for gradient clipping")
+    parser.add_argument("--lr_gamma", type=float, default=0.998, help="lr decay factor")

-    parser.add_argument('--model', default='', help="path to model (to continue training)")
+    parser.add_argument("--model", default="", help="path to model (to continue training)")

-    '''distributed'''
-    parser.add_argument('--world_size', default=1, type=int,
-                        help='Number of distributed nodes.')
-    parser.add_argument('--dist_url', default='tcp://127.0.0.1:9991', type=str,
-                        help='url used to set up distributed training')
-    parser.add_argument('--dist_backend', default='nccl', type=str,
-                        help='distributed backend')
-    parser.add_argument('--distribution_type', default='single', choices=['multi', 'single', None],
-                        help='Use multi-processing distributed training to launch '
-                             'N processes per node, which has N GPUs. This is the '
-                             'fastest way to use PyTorch for either single node or '
-                             'multi node data parallel training')
-    parser.add_argument('--rank', default=0, type=int,
-                        help='node rank for distributed training')
-    parser.add_argument('--gpu', default=None, type=int,
-                        help='GPU id to use. None means using all available GPUs.')
+    """distributed"""
+    parser.add_argument("--world_size", default=1, type=int, help="Number of distributed nodes.")
+    parser.add_argument(
+        "--dist_url", default="tcp://127.0.0.1:9991", type=str, help="url used to set up distributed training"
+    )
+    parser.add_argument("--dist_backend", default="nccl", type=str, help="distributed backend")
+    parser.add_argument(
+        "--distribution_type",
+        default="single",
+        choices=["multi", "single", None],
+        help="Use multi-processing distributed training to launch "
+        "N processes per node, which has N GPUs. This is the "
+        "fastest way to use PyTorch for either single node or "
+        "multi node data parallel training",
+    )
+    parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
+    parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")

-    '''eval'''
-    parser.add_argument('--saveIter', default=100, help='unit: epoch')
-    parser.add_argument('--diagIter', default=50, help='unit: epoch')
-    parser.add_argument('--vizIter', default=50, help='unit: epoch')
-    parser.add_argument('--print_freq', default=50, help='unit: iter')
-
-    parser.add_argument('--manualSeed', default=42, type=int, help='random seed')
+    """eval"""
+    parser.add_argument("--saveIter", default=100, help="unit: epoch")
+    parser.add_argument("--diagIter", default=50, help="unit: epoch")
+    parser.add_argument("--vizIter", default=50, help="unit: epoch")
+    parser.add_argument("--print_freq", default=50, help="unit: iter")
+    parser.add_argument("--manualSeed", default=42, type=int, help="random seed")

     opt = parser.parse_args()

     return opt


-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
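Representative invocations of the flags defined above, assuming the script is saved as train_generation.py (a hypothetical name; the output directory is derived from the file name via get_output_dir):

    python train_generation.py --category chair --distribution_type single
    python train_generation.py --category airplane --distribution_type multi --world_size 1

With --category airplane, main() swaps in the warm0.1 schedule and the tighter beta range before training starts.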
diff --git a/utils/file_utils.py b/utils/file_utils.py
index e6dbe67..b859bf8 100644
--- a/utils/file_utils.py
+++ b/utils/file_utils.py
@@ -1,35 +1,33 @@
+import datetime
+import logging
 import os
 import random
 import sys
-
 from shutil import copyfile
-import datetime

 import torch
-import logging

 logger = logging.getLogger()

 import numpy as np

+
 def set_global_gpu_env(opt):
-
-    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)
-
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
     torch.cuda.set_device(opt.gpu)

+
 def copy_source(file, output_dir):
     copyfile(file, os.path.join(output_dir, os.path.basename(file)))


-def setup_logging(output_dir):
     log_format = logging.Formatter("%(asctime)s : %(message)s")
     logger = logging.getLogger()
     logger.handlers = []
-    output_file = os.path.join(output_dir, 'output.log')
+    output_file = os.path.join(output_dir, "output.log")
     file_handler = logging.FileHandler(output_file)
     file_handler.setFormatter(log_format)
     logger.addHandler(file_handler)
@@ -44,16 +42,14 @@ def setup_logging(output_dir):


 def get_output_dir(prefix, exp_id):
-    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
-    output_dir = os.path.join(prefix, 'output/' + exp_id, t)
+    t = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    output_dir = os.path.join(prefix, "output/" + exp_id, t)
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
     return output_dir


 def set_seed(opt):
-
     if opt.manualSeed is None:
         opt.manualSeed = random.randint(1, 10000)
     print("Random Seed: ", opt.manualSeed)
@@ -65,8 +61,8 @@ def set_seed(opt):
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = True

-def setup_output_subdirs(output_dir, *subfolders):

+def setup_output_subdirs(output_dir, *subfolders):
     output_subdirs = output_dir
     try:
         os.makedirs(output_subdirs)
@@ -82,4 +78,4 @@ def setup_output_subdirs(output_dir, *subfolders):
             pass
         subfolder_list.append(curr_subf)

-    return subfolder_list
\ No newline at end of file
+    return subfolder_list
diff --git a/utils/metrics.py b/utils/metrics.py
index ec25e26..2feb1dd 100644
--- a/utils/metrics.py
+++ b/utils/metrics.py
@@ -1,20 +1,22 @@
-import numpy as np
 import warnings

+import numpy as np
 from scipy.stats import entropy

+
 def iterate_in_chunks(l, n):
-    '''Yield successive 'n'-sized chunks from iterable 'l'.
+    """Yield successive 'n'-sized chunks from iterable 'l'.
     Note: last chunk will be smaller than l if n doesn't divide l perfectly.
-    '''
+    """
     for i in range(0, len(l), n):
-        yield l[i:i + n]
+        yield l[i : i + n]
+

 def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
-    '''Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
-    that is placed in the unit-cube. If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
-    '''
+    """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
+    that is placed in the unit-cube. If clip_sphere is True it drops the "corner" cells that lie outside the unit-sphere.
+    """
     grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
     spacing = 1.0 / float(resolution - 1)
     for i in range(resolution):
@@ -30,9 +32,11 @@ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
     return grid, spacing


-def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False,
-                             use_EMD=False):
-    '''Computes the MMD between two sets of point-clouds.
+def minimum_mathing_distance(
+    sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False
+):
+    """Computes the MMD between two sets of point-clouds.
     Args:
         sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
            and compared to a set of "reference" point-clouds.
@@ -49,17 +53,17 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
-        use_EMD (boolean: If true, the matchings are based on the EMD.
+        use_EMD (boolean): If true, the matchings are based on the EMD.
     Returns:
         A tuple containing the MMD and all the matched distances of which the MMD is their mean.
-    '''
+    """
     n_ref, n_pc_points, pc_dim = ref_pcs.shape
     _, n_pc_points_s, pc_dim_s = sample_pcs.shape

     if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
-        raise ValueError('Incompatible size of point-clouds.')
+        raise ValueError("Incompatible size of point-clouds.")

-    ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(n_pc_points, normalize=normalize,
-                                                                                  sess=sess, use_sqrt=use_sqrt,
-                                                                                  use_EMD=use_EMD)
+    ref_pl, sample_pl, best_in_batch, _, sess = minimum_mathing_distance_tf_graph(
+        n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
+    )
     matched_dists = []
     for i in range(n_ref):
         best_in_all_batches = []
@@ -75,9 +79,18 @@ def minimum_mathing_distance(sample_pcs, ref_pcs, batch_size, normalize=True, se
     return mmd, matched_dists


-def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose=False, use_sqrt=False, use_EMD=False,
-             ret_dist=False):
-    '''Computes the Coverage between two sets of point-clouds.
+def coverage(
+    sample_pcs,
+    ref_pcs,
+    batch_size,
+    normalize=True,
+    sess=None,
+    verbose=False,
+    use_sqrt=False,
+    use_EMD=False,
+    ret_dist=False,
+):
+    """Computes the Coverage between two sets of point-clouds.
     Args:
         sample_pcs (numpy array SxKx3): the S point-clouds, each of K points that will be matched
            and compared to a set of "reference" point-clouds.
@@ -97,18 +110,16 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose
     Returns: the coverage score (int),
              the indices of the ref_pcs that are matched with each sample_pc
              and optionally the matched distances of the samples_pcs.
-    '''
+    """
     n_ref, n_pc_points, pc_dim = ref_pcs.shape
     n_sam, n_pc_points_s, pc_dim_s = sample_pcs.shape

     if n_pc_points != n_pc_points_s or pc_dim != pc_dim_s:
-        raise ValueError('Incompatible Point-Clouds.')
+        raise ValueError("Incompatible Point-Clouds.")

-    ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(n_pc_points,
-                                                                                            normalize=normalize,
-                                                                                            sess=sess,
-                                                                                            use_sqrt=use_sqrt,
-                                                                                            use_EMD=use_EMD)
+    ref_pl, sample_pl, best_in_batch, loc_of_best, sess = minimum_mathing_distance_tf_graph(
+        n_pc_points, normalize=normalize, sess=sess, use_sqrt=use_sqrt, use_EMD=use_EMD
+    )
     matched_gt = []
     matched_dist = []
-    for i in xrange(n_sam):
+    for i in range(n_sam):  # fix: xrange is Python 2; this module otherwise targets Python 3
@@ -140,12 +151,12 @@ def coverage(sample_pcs, ref_pcs, batch_size, normalize=True, sess=None, verbose


 def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
-    '''Computes the JSD between two sets of point-clouds, as introduced in the paper ```Learning Representations And Generative Models For 3D Point Clouds```.
+    """Computes the JSD between two sets of point-clouds, as introduced in the paper
+    "Learning Representations and Generative Models for 3D Point Clouds".
     Args:
-        sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
+        sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
         ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
         resolution: (int) grid-resolution. Affects granularity of measurements.
-    '''
+    """
     in_unit_sphere = True
     sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
     ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
@@ -153,19 +164,19 @@ def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):


 def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
-    '''Given a collection of point-clouds, estimate the entropy of the random variables
+    """Given a collection of point-clouds, estimate the entropy of the random variables
     corresponding to occupancy-grid activation patterns.
     Inputs:
         pclouds: (numpy array) #point-clouds x points per point-cloud x 3
         grid_resolution (int) size of occupancy grid that will be used.
-    '''
+    """
     epsilon = 10e-4
     bound = 0.5 + epsilon
     if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
-        warnings.warn('Point-clouds are not in unit cube.')
+        warnings.warn("Point-clouds are not in unit cube.")

-    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
-        warnings.warn('Point-clouds are not in unit sphere.')
+    if in_sphere and np.max(np.sqrt(np.sum(pclouds**2, axis=2))) > bound:
+        warnings.warn("Point-clouds are not in unit sphere.")

     grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
     grid_coordinates = grid_coordinates.reshape(-1, 3)
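jensen_shannon_divergence below computes H(M) - (H(P) + H(Q)) / 2 with M the midpoint of the two normalized histograms, then cross-checks it against _jsdiv. A hand-computed instance of that identity (my own check, not from the repo):

    import numpy as np
    from scipy.stats import entropy

    P = np.array([0.5, 0.5])
    Q = np.array([0.9, 0.1])
    M = 0.5 * (P + Q)
    jsd = entropy(M, base=2) - 0.5 * (entropy(P, base=2) + entropy(Q, base=2))
    # jsd ~ 0.1468 bits; identical distributions give 0, disjoint supports give 1.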
@@ -192,13 +203,14 @@ def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
     return acc_entropy / len(grid_counters), grid_counters

+
 def jensen_shannon_divergence(P, Q):
     if np.any(P < 0) or np.any(Q < 0):
-        raise ValueError('Negative values.')
+        raise ValueError("Negative values.")
     if len(P) != len(Q):
-        raise ValueError('Non equal size.')
+        raise ValueError("Non equal size.")

-    P_ = P / np.sum(P) # Ensure probabilities.
+    P_ = P / np.sum(P)  # Ensure probabilities.
     Q_ = Q / np.sum(Q)

     e1 = entropy(P_, base=2)
@@ -209,13 +221,14 @@ def jensen_shannon_divergence(P, Q):
     res2 = _jsdiv(P_, Q_)

     if not np.allclose(res, res2, atol=10e-5, rtol=0):
-        warnings.warn('Numerical values of two JSD methods don\'t agree.')
+        warnings.warn("Numerical values of two JSD methods don't agree.")

     return res


 def _jsdiv(P, Q):
-    '''another way of computing JSD'''
+    """another way of computing JSD"""
+
     def _kldiv(A, B):
         a = A.copy()
         b = B.copy()
@@ -229,4 +242,4 @@ def _jsdiv(P, Q):

     M = 0.5 * (P_ + Q_)

-    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
\ No newline at end of file
+    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
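End to end, the set-level JSD above consumes two arrays of clouds that roughly fit the unit sphere. A shape-only sketch with random data (the score on noise is meaningless; this just exercises the call):

    import numpy as np
    sample_pcs = (np.random.rand(8, 2048, 3) - 0.5) * 0.5  # keeps norms under the 0.5 bound checked above
    ref_pcs = (np.random.rand(8, 2048, 3) - 0.5) * 0.5
    score = jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28)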
- """ + """Save a 3D mesh to a polygon .ply file.""" # Write header - ply_file = open(filename, 'w') + ply_file = open(filename, "w") ply_file.write("ply\n") ply_file.write("format ascii 1.0\n") ply_file.write("element vertex %d\n" % (verts.shape[0])) @@ -78,11 +80,20 @@ def meshwrite(filename, verts, faces, norms, colors): # Write vertex list for i in range(verts.shape[0]): - ply_file.write("%f %f %f %f %f %f %d %d %d\n" % ( - verts[i, 0], verts[i, 1], verts[i, 2], - norms[i, 0], norms[i, 1], norms[i, 2], - colors[i, 0], colors[i, 1], colors[i, 2], - )) + ply_file.write( + "%f %f %f %f %f %f %d %d %d\n" + % ( + verts[i, 0], + verts[i, 1], + verts[i, 2], + norms[i, 0], + norms[i, 1], + norms[i, 2], + colors[i, 0], + colors[i, 1], + colors[i, 2], + ) + ) # Write face list for i in range(faces.shape[0]): @@ -92,14 +103,13 @@ def meshwrite(filename, verts, faces, norms, colors): def pcwrite(filename, xyz, rgb=None): - """Save a point cloud to a polygon .ply file. - """ + """Save a point cloud to a polygon .ply file.""" if rgb is None: rgb = np.ones_like(xyz) * 128 rgb = rgb.astype(np.uint8) # Write header - ply_file = open(filename, 'w') + ply_file = open(filename, "w") ply_file.write("ply\n") ply_file.write("format ascii 1.0\n") ply_file.write("element vertex %d\n" % (xyz.shape[0])) @@ -113,60 +123,67 @@ def pcwrite(filename, xyz, rgb=None): # Write vertex list for i in range(xyz.shape[0]): - ply_file.write("%f %f %f %d %d %d\n" % ( - xyz[i, 0], xyz[i, 1], xyz[i, 2], - rgb[i, 0], rgb[i, 1], rgb[i, 2], - )) + ply_file.write( + "%f %f %f %d %d %d\n" + % ( + xyz[i, 0], + xyz[i, 1], + xyz[i, 2], + rgb[i, 0], + rgb[i, 1], + rgb[i, 2], + ) + ) -''' + +""" Matplotlib Visualization -''' +""" + def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5): - r''' Visualizes voxel data. + r"""Visualizes voxel data. show only first num_shown - ''' - batch_size =voxels.shape[0] + """ + batch_size = voxels.shape[0] voxels = voxels.squeeze(1) > threshold num_shown = min(num_shown, batch_size) n = int(np.sqrt(num_shown)) - fig = plt.figure(figsize=(20,20)) + fig = plt.figure(figsize=(20, 20)) for idx, pc in enumerate(voxels[:num_shown]): - if idx >= n*n: + if idx >= n * n: break pc = voxels[idx] - ax = fig.add_subplot(n, n, idx + 1, projection='3d') - ax.voxels(pc, edgecolor='k', facecolors='green', linewidth=0.1, alpha=0.5) + ax = fig.add_subplot(n, n, idx + 1, projection="3d") + ax.voxels(pc, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5) ax.view_init() - ax.axis('off') - plt.savefig(out_file, bbox_inches='tight') + ax.axis("off") + plt.savefig(out_file, bbox_inches="tight") plt.close() -def visualize_pointcloud(points, normals=None, - out_file=None, show=False, elev=30, azim=225): - r''' Visualizes point cloud data. + +def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225): + r"""Visualizes point cloud data. 


-def visualize_pointcloud(points, normals=None,
-                         out_file=None, show=False, elev=30, azim=225):
-    r''' Visualizes point cloud data.
+def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
+    r"""Visualizes point cloud data.
     Args:
         points (tensor): point data
         normals (tensor): normal data (if existing)
         out_file (string): output file
         show (bool): whether the plot should be shown
-    '''
+    """
     # Create plot
     fig = plt.figure()
     ax = fig.gca(projection=Axes3D.name)
     ax.scatter(points[:, 2], points[:, 0], points[:, 1])
     if normals is not None:
         ax.quiver(
-            points[:, 2], points[:, 0], points[:, 1],
-            normals[:, 2], normals[:, 0], normals[:, 1],
-            length=0.1, color='k'
+            points[:, 2], points[:, 0], points[:, 1], normals[:, 2], normals[:, 0], normals[:, 1], length=0.1, color="k"
         )
-    ax.set_xlabel('Z')
-    ax.set_ylabel('X')
-    ax.set_zlabel('Y')
+    ax.set_xlabel("Z")
+    ax.set_ylabel("X")
+    ax.set_zlabel("Y")
     # ax.set_xlim(-0.5, 0.5)
     # ax.set_ylim(-0.5, 0.5)
     # ax.set_zlim(-0.5, 0.5)
@@ -178,37 +195,39 @@ def visualize_pointcloud(points, normals=None,
     plt.close(fig)


-def visualize_pointcloud_batch(path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225):
+def visualize_pointcloud_batch(
+    path, pointclouds, pred_labels, labels, categories, vis_label=False, target=None, elev=30, azim=225
+):
     batch_size = len(pointclouds)

-    fig = plt.figure(figsize=(20,20))
+    fig = plt.figure(figsize=(20, 20))

     ncols = int(np.sqrt(batch_size))
-    nrows = max(1, (batch_size-1) // ncols+1)
+    nrows = max(1, (batch_size - 1) // ncols + 1)
     for idx, pc in enumerate(pointclouds):
         if vis_label:
             label = categories[labels[idx].item()]
             pred = categories[pred_labels[idx]]
-            colour = 'g' if label == pred else 'r'
+            colour = "g" if label == pred else "r"
         elif target is None:
-
-            colour = 'g'
+            colour = "g"
         else:
             colour = target[idx]
         pc = pc.cpu().numpy()
-        ax = fig.add_subplot(nrows, ncols, idx + 1, projection='3d')
+        ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
         ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
         ax.view_init(elev=elev, azim=azim)
-        ax.axis('off')
+        ax.axis("off")
         if vis_label:
-            ax.set_title('GT: {0}\nPred: {1}'.format(label, pred))
+            ax.set_title("GT: {0}\nPred: {1}".format(label, pred))

     plt.savefig(path)
     plt.close(fig)


-'''
+"""
 Plot stats
-'''
+"""
+

 def plot_stats(output_dir, stats, interval):
     content = stats.keys()
@@ -218,5 +237,5 @@ def plot_stats(output_dir, stats, interval):
         axs[j].plot(interval, v)
         axs[j].set_ylabel(k)

-    f.savefig(os.path.join(output_dir, 'stat.pdf'), bbox_inches='tight')
+    f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
     plt.close(f)
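Closing the loop with the training script: the batch plotter is called with the model's B x 3 x N layout transposed to B x N x 3. A standalone sketch, with a random tensor standing in for generated samples:

    import torch
    x_gen = torch.randn(16, 3, 2048)  # same layout train() gets from gen_samples
    visualize_pointcloud_batch("samples.png", x_gen.transpose(1, 2), None, None, None)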