init
This commit is contained in:
parent
1118b9b0c0
commit
1d24a7879d
14
.gitignore
vendored
Normal file
14
.gitignore
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
__pycache__/
|
||||
*__pycache__
|
||||
.idea/
|
||||
*.pyc
|
||||
*.m
|
||||
.ipynb_checkpoints
|
||||
*swp
|
||||
*swo
|
||||
*__pycache__*
|
||||
models/pvcnn/functional/build/
|
||||
*.sh
|
||||
lion_ckpt
|
||||
data/
|
||||
datasets/test_data
|
76
README.md
76
README.md
|
@ -1,20 +1,76 @@
|
|||
## <p align="center">LION: Latent Point Diffusion Models for 3D Shape Generation<br><br> NeurIPS 2022 </p>
|
||||
<div align="center">
|
||||
<a href="https://www.cs.utoronto.ca/~xiaohui/" target="_blank">Xiaohui Zeng</a>   <b>·</b>  
|
||||
<a href="http://latentspace.cc/" target="_blank">Arash Vahdat</a>   <b>·</b>  
|
||||
<a href="https://www.fwilliams.info/" target="_blank">Francis Williams</a>   <b>·</b>  
|
||||
<a href="https://zgojcic.github.io/" target="_blank">Zan Gojcic</a>   <b>·</b>  
|
||||
<a href="https://orlitany.github.io/" target="_blank">Or Litany</a>   <b>·</b>  
|
||||
<a href="https://www.cs.utoronto.ca/~fidler/" target="_blank">Sanja Fidler</a>   <b>·</b>  
|
||||
<a href="https://www.cs.utoronto.ca/~xiaohui/" target="_blank">Xiaohui Zeng</a>  
|
||||
<a href="http://latentspace.cc/" target="_blank">Arash Vahdat</a>  
|
||||
<a href="https://www.fwilliams.info/" target="_blank">Francis Williams</a>  
|
||||
<a href="https://zgojcic.github.io/" target="_blank">Zan Gojcic</a>  
|
||||
<a href="https://orlitany.github.io/" target="_blank">Or Litany</a>  
|
||||
<a href="https://www.cs.utoronto.ca/~fidler/" target="_blank">Sanja Fidler</a>  
|
||||
<a href="https://karstenkreis.github.io/" target="_blank">Karsten Kreis</a>
|
||||
<br> <br>
|
||||
<a href="https://arxiv.org/abs/2210.06978" target="_blank">Paper</a>  
|
||||
<a href="https://nv-tlabs.github.io/LION" target="_blank">Project Page</a>
|
||||
</div>
|
||||
<br><br>
|
||||
<p align="center">:construction: :pick: :hammer_and_wrench: :construction_worker:</p>
|
||||
<p align="center">Here, we will release code and checkpoints in the near future! Stay tuned!</p>
|
||||
<br><br>
|
||||
|
||||
<p align="center">
|
||||
<img width="750" alt="Animation" src="assets/animation.gif"/>
|
||||
</p>
|
||||
## Install
|
||||
* Dependencies:
|
||||
* CUDA 11.6
|
||||
|
||||
* Setup the environment
|
||||
Install from conda file
|
||||
```
|
||||
conda env create --name lion_env --file=env.yaml
|
||||
conda activate lion_env
|
||||
|
||||
# Install some other packages
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
|
||||
# build some packages first (optional)
|
||||
python build_pkg.py
|
||||
```
|
||||
Tested with conda version 22.9.0
|
||||
|
||||
## Demo
|
||||
run `python demo.py`, will load the released text2shape model on hugging face and generate a chair point cloud.
|
||||
|
||||
## Released checkpoint and samples
|
||||
* will be release soon
|
||||
* put the downloaded file under `./lion_ckpt/`
|
||||
|
||||
## Training
|
||||
|
||||
### data
|
||||
* ShapeNet can be downloaded [here](https://github.com/stevenygd/PointFlow#dataset).
|
||||
* Put the downloaded data as `./data/ShapeNetCore.v2.PC15k` *or* edit the `pointflow` entry in `./datasets/data_path.py` for the ShapeNet dataset path.
|
||||
|
||||
### train VAE
|
||||
* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4`)
|
||||
|
||||
### train diffusion prior
|
||||
* require the vae checkpoint
|
||||
* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` with 2 node)
|
||||
|
||||
### evaluate a trained prior
|
||||
* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`
|
||||
* download the released checkpoint from above
|
||||
```
|
||||
checkpoint="./lion_ckpt/unconditional/airplane/checkpoints/model.pt"
|
||||
bash ./script/eval.sh $checkpoint # will take 1-2 hour
|
||||
```
|
||||
|
||||
## Evaluate the samples with the 1-NNA metrics
|
||||
* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`
|
||||
* run `python ./script/compute_score.py`
|
||||
|
||||
## Citation
|
||||
```
|
||||
@inproceedings{zeng2022lion,
|
||||
title={LION: Latent Point Diffusion Models for 3D Shape Generation},
|
||||
author={Xiaohui Zeng and Arash Vahdat and Francis Williams and Zan Gojcic and Or Litany and Sanja Fidler and Karsten Kreis},
|
||||
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
|
3
build_pkg.py
Normal file
3
build_pkg.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
import clip
|
||||
from models import pvcnn2
|
||||
from utils import eval_helper
|
38
datasets/data_path.py
Normal file
38
datasets/data_path.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import os
|
||||
|
||||
|
||||
def get_path(dataname=None):
|
||||
dataset_path = {}
|
||||
dataset_path['pointflow'] = [
|
||||
'./data/ShapeNetCore.v2.PC15k/'
|
||||
|
||||
]
|
||||
|
||||
if dataname is None:
|
||||
return dataset_path
|
||||
else:
|
||||
assert(
|
||||
dataname in dataset_path), f'not found {dataname}, only: {list(dataset_path.keys())}'
|
||||
for p in dataset_path[dataname]:
|
||||
print(f'searching: {dataname}, get: {p}')
|
||||
if os.path.exists(p):
|
||||
return p
|
||||
ValueError(
|
||||
f'all path not found for {dataname}, please double check: {dataset_path[dataname]}; or edit the datasets/data_path.py ')
|
||||
|
||||
|
||||
def get_cache_path():
|
||||
cache_list = ['/workspace/data_cache_local/data_stat/',
|
||||
'/workspace/data_cache/data_stat/']
|
||||
for p in cache_list:
|
||||
if os.path.exists(p):
|
||||
return p
|
||||
ValueError(
|
||||
f'all path not found for {cache_list}, please double check: or edit the datasets/data_path.py ')
|
404
datasets/pointflow_datasets.py
Normal file
404
datasets/pointflow_datasets.py
Normal file
|
@ -0,0 +1,404 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
|
||||
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
|
||||
import os
|
||||
import open3d as o3d
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils import data
|
||||
import random
|
||||
import tqdm
|
||||
from datasets.data_path import get_path
|
||||
OVERFIT = 0
|
||||
|
||||
# taken from https://github.com/optas/latent_3d_points/blob/
|
||||
# 8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
|
||||
synsetid_to_cate = {
|
||||
'02691156': 'airplane',
|
||||
'02773838': 'bag',
|
||||
'02801938': 'basket',
|
||||
'02808440': 'bathtub',
|
||||
'02818832': 'bed',
|
||||
'02828884': 'bench',
|
||||
'02876657': 'bottle',
|
||||
'02880940': 'bowl',
|
||||
'02924116': 'bus',
|
||||
'02933112': 'cabinet',
|
||||
'02747177': 'can',
|
||||
'02942699': 'camera',
|
||||
'02954340': 'cap',
|
||||
'02958343': 'car',
|
||||
'03001627': 'chair',
|
||||
'03046257': 'clock',
|
||||
'03207941': 'dishwasher',
|
||||
'03211117': 'monitor',
|
||||
'04379243': 'table',
|
||||
'04401088': 'telephone',
|
||||
'02946921': 'tin_can',
|
||||
'04460130': 'tower',
|
||||
'04468005': 'train',
|
||||
'03085013': 'keyboard',
|
||||
'03261776': 'earphone',
|
||||
'03325088': 'faucet',
|
||||
'03337140': 'file',
|
||||
'03467517': 'guitar',
|
||||
'03513137': 'helmet',
|
||||
'03593526': 'jar',
|
||||
'03624134': 'knife',
|
||||
'03636649': 'lamp',
|
||||
'03642806': 'laptop',
|
||||
'03691459': 'speaker',
|
||||
'03710193': 'mailbox',
|
||||
'03759954': 'microphone',
|
||||
'03761084': 'microwave',
|
||||
'03790512': 'motorcycle',
|
||||
'03797390': 'mug',
|
||||
'03928116': 'piano',
|
||||
'03938244': 'pillow',
|
||||
'03948459': 'pistol',
|
||||
'03991062': 'pot',
|
||||
'04004475': 'printer',
|
||||
'04074963': 'remote_control',
|
||||
'04090263': 'rifle',
|
||||
'04099429': 'rocket',
|
||||
'04225987': 'skateboard',
|
||||
'04256520': 'sofa',
|
||||
'04330267': 'stove',
|
||||
'04530566': 'vessel',
|
||||
'04554684': 'washer',
|
||||
'02992529': 'cellphone',
|
||||
'02843684': 'birdhouse',
|
||||
'02871439': 'bookshelf',
|
||||
# '02858304': 'boat', no boat in our dataset, merged into vessels
|
||||
# '02834778': 'bicycle', not in our taxonomy
|
||||
}
|
||||
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
|
||||
|
||||
|
||||
class ShapeNet15kPointClouds(Dataset):
|
||||
def __init__(self,
|
||||
categories=['airplane'],
|
||||
tr_sample_size=10000,
|
||||
te_sample_size=10000,
|
||||
split='train',
|
||||
scale=1.,
|
||||
normalize_per_shape=False,
|
||||
normalize_shape_box=False,
|
||||
random_subsample=False,
|
||||
sample_with_replacement=1,
|
||||
normalize_std_per_axis=False,
|
||||
normalize_global=False,
|
||||
recenter_per_shape=False,
|
||||
all_points_mean=None,
|
||||
all_points_std=None,
|
||||
input_dim=3,
|
||||
):
|
||||
self.normalize_shape_box = normalize_shape_box
|
||||
root_dir = get_path('pointflow')
|
||||
self.root_dir = root_dir
|
||||
logger.info('[DATA] cat: {}, split: {}, full path: {}; norm global={}, norm-box={}',
|
||||
categories, split, self.root_dir, normalize_global, normalize_shape_box)
|
||||
|
||||
self.split = split
|
||||
assert self.split in ['train', 'test', 'val']
|
||||
self.tr_sample_size = tr_sample_size
|
||||
self.te_sample_size = te_sample_size
|
||||
if type(categories) is str:
|
||||
categories = [categories]
|
||||
self.cates = categories
|
||||
|
||||
if 'all' in categories:
|
||||
self.synset_ids = list(cate_to_synsetid.values())
|
||||
else:
|
||||
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
|
||||
subdirs = self.synset_ids
|
||||
# assert 'v2' in root_dir, "Only supporting v2 right now."
|
||||
self.gravity_axis = 1
|
||||
self.display_axis_order = [0, 2, 1]
|
||||
|
||||
self.root_dir = root_dir
|
||||
self.split = split
|
||||
self.in_tr_sample_size = tr_sample_size
|
||||
self.in_te_sample_size = te_sample_size
|
||||
self.subdirs = subdirs
|
||||
self.scale = scale
|
||||
self.random_subsample = random_subsample
|
||||
self.sample_with_replacement = sample_with_replacement
|
||||
self.input_dim = input_dim
|
||||
|
||||
self.all_cate_mids = []
|
||||
self.cate_idx_lst = []
|
||||
self.all_points = []
|
||||
tic = time.time()
|
||||
for cate_idx, subd in enumerate(self.subdirs):
|
||||
# NOTE: [subd] here is synset id
|
||||
sub_path = os.path.join(root_dir, subd, self.split)
|
||||
if not os.path.isdir(sub_path):
|
||||
print("Directory missing : %s " % (sub_path))
|
||||
raise ValueError('check the data path')
|
||||
continue
|
||||
if True:
|
||||
all_mids = []
|
||||
assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
|
||||
for x in os.listdir(sub_path):
|
||||
if not x.endswith('.npy'):
|
||||
continue
|
||||
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
|
||||
|
||||
logger.info('[DATA] number of file [{}] under: {} ',
|
||||
len(os.listdir(sub_path)), sub_path)
|
||||
# NOTE: [mid] contains the split: i.e. "train/<mid>"
|
||||
# or "val/<mid>" or "test/<mid>"
|
||||
all_mids = sorted(all_mids)
|
||||
for mid in all_mids:
|
||||
# obj_fname = os.path.join(sub_path, x)
|
||||
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
|
||||
point_cloud = np.load(obj_fname) # (15k, 3)
|
||||
self.all_points.append(point_cloud[np.newaxis, ...])
|
||||
self.cate_idx_lst.append(cate_idx)
|
||||
self.all_cate_mids.append((subd, mid))
|
||||
|
||||
logger.info('[DATA] Load data time: {:.1f}s | dir: {} | '
|
||||
'sample_with_replacement: {}; num points: {}', time.time() - tic, self.subdirs,
|
||||
self.sample_with_replacement, len(self.all_points))
|
||||
|
||||
# Shuffle the index deterministically (based on the number of examples)
|
||||
self.shuffle_idx = list(range(len(self.all_points)))
|
||||
random.Random(38383).shuffle(self.shuffle_idx)
|
||||
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
|
||||
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
|
||||
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
|
||||
|
||||
# Normalization
|
||||
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
|
||||
self.normalize_per_shape = normalize_per_shape
|
||||
self.normalize_std_per_axis = normalize_std_per_axis
|
||||
self.recenter_per_shape = recenter_per_shape
|
||||
if self.normalize_shape_box: # per shape normalization
|
||||
B, N = self.all_points.shape[:2]
|
||||
self.all_points_mean = ( # B,1,3
|
||||
(np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) +
|
||||
(np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)) / 2
|
||||
self.all_points_std = np.amax( # B,1,1
|
||||
((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) -
|
||||
(np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)),
|
||||
axis=-1).reshape(B, 1, 1) / 2
|
||||
elif self.normalize_per_shape: # per shape normalization
|
||||
B, N = self.all_points.shape[:2]
|
||||
self.all_points_mean = self.all_points.mean(axis=1).reshape(
|
||||
B, 1, input_dim)
|
||||
logger.info('all_points shape: {}. mean over axis=1',
|
||||
self.all_points.shape)
|
||||
if normalize_std_per_axis:
|
||||
self.all_points_std = self.all_points.reshape(
|
||||
B, N, -1).std(axis=1).reshape(B, 1, input_dim)
|
||||
else:
|
||||
self.all_points_std = self.all_points.reshape(
|
||||
B, -1).std(axis=1).reshape(B, 1, 1)
|
||||
elif all_points_mean is not None and all_points_std is not None and not self.recenter_per_shape:
|
||||
# using loaded dataset stats
|
||||
self.all_points_mean = all_points_mean
|
||||
self.all_points_std = all_points_std
|
||||
elif self.recenter_per_shape: # per shape center
|
||||
# TODO: bounding box scale at the large dim and center
|
||||
B, N = self.all_points.shape[:2]
|
||||
self.all_points_mean = (
|
||||
(np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) +
|
||||
(np.amin(self.all_points, axis=1)).reshape(B, 1,
|
||||
input_dim)) / 2
|
||||
self.all_points_std = np.amax(
|
||||
((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) -
|
||||
(np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)),
|
||||
axis=-1).reshape(B, 1, 1) / 2
|
||||
# else: # normalize across the dataset
|
||||
elif normalize_global: # normalize across the dataset
|
||||
self.all_points_mean = self.all_points.reshape(
|
||||
-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
|
||||
|
||||
if normalize_std_per_axis:
|
||||
self.all_points_std = self.all_points.reshape(
|
||||
-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
|
||||
else:
|
||||
self.all_points_std = self.all_points.reshape(-1).std(
|
||||
axis=0).reshape(1, 1, 1)
|
||||
|
||||
logger.info('[DATA] normalize_global: mean={}, std={}',
|
||||
self.all_points_mean.reshape(-1),
|
||||
self.all_points_std.reshape(-1))
|
||||
else:
|
||||
raise NotImplementedError('No Normalization')
|
||||
self.all_points = (self.all_points - self.all_points_mean) / \
|
||||
self.all_points_std
|
||||
logger.info('[DATA] shape={}, all_points_mean:={}, std={}, max={:.3f}, min={:.3f}; num-pts={}',
|
||||
self.all_points.shape,
|
||||
self.all_points_mean.shape, self.all_points_std.shape,
|
||||
self.all_points.max(), self.all_points.min(), tr_sample_size)
|
||||
|
||||
if OVERFIT:
|
||||
self.all_points = self.all_points[:40]
|
||||
|
||||
# TODO: why do we need this??
|
||||
self.train_points = self.all_points[:, :min(
|
||||
10000, self.all_points.shape[1])]
|
||||
self.tr_sample_size = min(10000, tr_sample_size)
|
||||
self.te_sample_size = min(5000, te_sample_size)
|
||||
assert self.scale == 1, "Scale (!= 1) is deprecated"
|
||||
|
||||
# Default display axis order
|
||||
self.display_axis_order = [0, 1, 2]
|
||||
|
||||
def get_pc_stats(self, idx):
|
||||
if self.recenter_per_shape:
|
||||
m = self.all_points_mean[idx].reshape(1, self.input_dim)
|
||||
s = self.all_points_std[idx].reshape(1, -1)
|
||||
return m, s
|
||||
|
||||
if self.normalize_per_shape or self.normalize_shape_box:
|
||||
m = self.all_points_mean[idx].reshape(1, self.input_dim)
|
||||
s = self.all_points_std[idx].reshape(1, -1)
|
||||
return m, s
|
||||
|
||||
return self.all_points_mean.reshape(1, -1), \
|
||||
self.all_points_std.reshape(1, -1)
|
||||
|
||||
def renormalize(self, mean, std):
|
||||
self.all_points = self.all_points * self.all_points_std + \
|
||||
self.all_points_mean
|
||||
self.all_points_mean = mean
|
||||
self.all_points_std = std
|
||||
self.all_points = (self.all_points - self.all_points_mean) / \
|
||||
self.all_points_std
|
||||
self.train_points = self.all_points[:, :min(
|
||||
10000, self.all_points.shape[1])]
|
||||
## self.test_points = self.all_points[:, 10000:]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.train_points)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
output = {}
|
||||
tr_out = self.train_points[idx]
|
||||
if self.random_subsample and self.sample_with_replacement:
|
||||
tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
|
||||
elif self.random_subsample and not self.sample_with_replacement:
|
||||
tr_idxs = np.random.permutation(
|
||||
np.arange(tr_out.shape[0]))[:self.tr_sample_size]
|
||||
else:
|
||||
tr_idxs = np.arange(self.tr_sample_size)
|
||||
tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
|
||||
m, s = self.get_pc_stats(idx)
|
||||
|
||||
cate_idx = self.cate_idx_lst[idx]
|
||||
sid, mid = self.all_cate_mids[idx]
|
||||
input_pts = tr_out
|
||||
|
||||
output.update(
|
||||
{
|
||||
'idx': idx,
|
||||
'select_idx': tr_idxs,
|
||||
'tr_points': tr_out,
|
||||
'input_pts': input_pts,
|
||||
'mean': m,
|
||||
'std': s,
|
||||
'cate_idx': cate_idx,
|
||||
'sid': sid,
|
||||
'mid': mid,
|
||||
'display_axis_order': self.display_axis_order
|
||||
})
|
||||
return output
|
||||
|
||||
|
||||
def init_np_seed(worker_id):
|
||||
seed = torch.initial_seed()
|
||||
np.random.seed(seed % 4294967296)
|
||||
|
||||
|
||||
def get_datasets(cfg, args):
|
||||
"""
|
||||
cfg: config.data sub part
|
||||
"""
|
||||
if OVERFIT:
|
||||
random_subsample = 0
|
||||
else:
|
||||
random_subsample = cfg.random_subsample
|
||||
logger.info(f'get_datasets: tr_sample_size={cfg.tr_max_sample_points}, '
|
||||
f' te_sample_size={cfg.te_max_sample_points}; '
|
||||
f' random_subsample={random_subsample}'
|
||||
f' normalize_global={cfg.normalize_global}'
|
||||
f' normalize_std_per_axix={cfg.normalize_std_per_axis}'
|
||||
f' normalize_per_shape={cfg.normalize_per_shape}'
|
||||
f' recenter_per_shape={cfg.recenter_per_shape}'
|
||||
)
|
||||
kwargs = {}
|
||||
tr_dataset = ShapeNet15kPointClouds(
|
||||
categories=cfg.cates,
|
||||
split='train',
|
||||
tr_sample_size=cfg.tr_max_sample_points,
|
||||
te_sample_size=cfg.te_max_sample_points,
|
||||
sample_with_replacement=cfg.sample_with_replacement,
|
||||
scale=cfg.dataset_scale, # root_dir=cfg.data_dir,
|
||||
normalize_shape_box=cfg.normalize_shape_box,
|
||||
normalize_per_shape=cfg.normalize_per_shape,
|
||||
normalize_std_per_axis=cfg.normalize_std_per_axis,
|
||||
normalize_global=cfg.normalize_global,
|
||||
recenter_per_shape=cfg.recenter_per_shape,
|
||||
random_subsample=random_subsample,
|
||||
**kwargs)
|
||||
|
||||
eval_split = getattr(args, "eval_split", "val")
|
||||
# te_dataset has random_subsample as False, therefore not using sample_with_replacement
|
||||
te_dataset = ShapeNet15kPointClouds(
|
||||
categories=cfg.cates,
|
||||
split=eval_split,
|
||||
tr_sample_size=cfg.tr_max_sample_points,
|
||||
te_sample_size=cfg.te_max_sample_points,
|
||||
scale=cfg.dataset_scale, # root_dir=cfg.data_dir,
|
||||
normalize_shape_box=cfg.normalize_shape_box,
|
||||
normalize_per_shape=cfg.normalize_per_shape,
|
||||
normalize_std_per_axis=cfg.normalize_std_per_axis,
|
||||
normalize_global=cfg.normalize_global,
|
||||
recenter_per_shape=cfg.recenter_per_shape,
|
||||
all_points_mean=tr_dataset.all_points_mean,
|
||||
all_points_std=tr_dataset.all_points_std,
|
||||
)
|
||||
return tr_dataset, te_dataset
|
||||
|
||||
|
||||
def get_data_loaders(cfg, args):
|
||||
tr_dataset, te_dataset = get_datasets(cfg, args)
|
||||
kwargs = {}
|
||||
if args.distributed:
|
||||
kwargs['sampler'] = data.distributed.DistributedSampler(
|
||||
tr_dataset, shuffle=True)
|
||||
else:
|
||||
kwargs['shuffle'] = True
|
||||
if args.eval_trainnll:
|
||||
kwargs['shuffle'] = False
|
||||
train_loader = data.DataLoader(dataset=tr_dataset,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
drop_last=cfg.train_drop_last == 1,
|
||||
pin_memory=False, **kwargs)
|
||||
test_loader = data.DataLoader(dataset=te_dataset,
|
||||
batch_size=cfg.batch_size_test,
|
||||
shuffle=False,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
)
|
||||
logger.info(
|
||||
f'[Batch Size] train={cfg.batch_size}, test={cfg.batch_size_test}; drop-last={cfg.train_drop_last}')
|
||||
loaders = {
|
||||
"test_loader": test_loader,
|
||||
'train_loader': train_loader,
|
||||
}
|
||||
return loaders
|
450
default_config.py
Normal file
450
default_config.py
Normal file
|
@ -0,0 +1,450 @@
|
|||
# ---------------------------------------------------------------
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
|
||||
from third_party.yacs_config import CfgNode as CN
|
||||
|
||||
cfg = CN()
|
||||
cfg.dpm_ckpt = ''
|
||||
cfg.clipforge = CN()
|
||||
cfg.clipforge.clip_model = "ViT-B/32"
|
||||
cfg.clipforge.enable = 0
|
||||
cfg.clipforge.feat_dim = 512
|
||||
cfg.eval_trainnll = 0
|
||||
cfg.exp_name = ''
|
||||
cfg.cmt = ''
|
||||
cfg.hash = ''
|
||||
cfg.ngpu = 1
|
||||
cfg.snapshot_min = 30 # snapshot every 30 min
|
||||
cfg.bash_name = ''
|
||||
cfg.set_detect_anomaly = 0
|
||||
cfg.weight_recont = 1.0
|
||||
# vae ckpt
|
||||
# lns
|
||||
cfg.use_checkpoint = 0
|
||||
cfg.num_val_samples = 16 # 24 #12
|
||||
|
||||
# config for pointtransformer
|
||||
cfg.eval = CN()
|
||||
cfg.eval.need_denoise = 0
|
||||
cfg.eval.load_other_vae_ckpt = 0
|
||||
cfg.register_deprecated_key('eval.other_vae_ckpt_path')
|
||||
cfg.vis_latent_point = 0
|
||||
cfg.latent_pts = CN()
|
||||
#cfg.latent_pts.class_embed_layer = ''
|
||||
cfg.register_deprecated_key('latent_pts.class_embed_layer')
|
||||
cfg.latent_pts.style_dim = 128 # dim of global style latent variable
|
||||
cfg.register_deprecated_key('latent_pts.perturb_input')
|
||||
cfg.register_deprecated_key('latent_pts.perturb_input_scale')
|
||||
cfg.register_deprecated_key('latent_pts.outlier_input')
|
||||
|
||||
# scale of init weights for the mlp in adaGN layer
|
||||
cfg.latent_pts.ada_mlp_init_scale = 1.0
|
||||
# models.latent_points_ada.StyleMLP' # style mlp layers
|
||||
cfg.latent_pts.style_mlp = ''
|
||||
cfg.latent_pts.pts_sigma_offset = 0.0
|
||||
cfg.latent_pts.skip_weight = 0.1
|
||||
cfg.latent_pts.encoder_layer_out_dim = 32
|
||||
cfg.latent_pts.decoder_layer_out_dim = 32
|
||||
cfg.register_deprecated_key('latent_pts.encoder_nneighbor')
|
||||
cfg.register_deprecated_key('latent_pts.decoder_nneighbor')
|
||||
cfg.latent_pts.style_prior = 'models.score_sde.resnet.PriorSEDrop'
|
||||
cfg.latent_pts.mask_out_extra_latent = 0 # use only latent coordinates
|
||||
# latent coordinates directly same as input (not using the decoder and encoder)
|
||||
cfg.register_deprecated_key('latent_pts.latent_as_pts')
|
||||
|
||||
cfg.latent_pts.normalization = 'bn' # BatchNorm or LayerNorm
|
||||
cfg.latent_pts.pvd_mse_loss = 0
|
||||
cfg.latent_pts.hid = 64
|
||||
|
||||
cfg.register_deprecated_key('latent_pts.knn')
|
||||
cfg.register_deprecated_key('latent_pts.n5layer')
|
||||
cfg.register_deprecated_key('latent_pts.dgcnn_last_hid')
|
||||
|
||||
cfg.latent_pts.latent_dim_ext = [64] # the global latent dim
|
||||
cfg.latent_pts.weight_kl_pt = 1.0 # kl ratio of the pts
|
||||
cfg.latent_pts.weight_kl_feat = 1.0 # kl ratio of the latent feat
|
||||
cfg.latent_pts.weight_kl_glb = 1.0 # kl ratio of the latent feat
|
||||
# kl ratio of the latent feat
|
||||
cfg.latent_pts.style_encoder = 'models.shapelatent_modules.PointNetPlusEncoder'
|
||||
cfg.latent_pts.use_linear_for_adagn = 0
|
||||
# cfg.latent_pts.weight_kl_glb = 1.0 # kl ratio of the global latent
|
||||
|
||||
# shapelatent:
|
||||
cfg.has_shapelatent = 1
|
||||
cfg.shapelatent = CN()
|
||||
cfg.shapelatent.local_emb_agg = 'mean'
|
||||
cfg.shapelatent.freeze_vae = 0 # learn vae
|
||||
cfg.shapelatent.eps_z_global_only = 1
|
||||
cfg.shapelatent.model = 'flow'
|
||||
cfg.shapelatent.residual = 1
|
||||
cfg.shapelatent.encoder_type = 'pointnet'
|
||||
cfg.shapelatent.prior_type = 'flow'
|
||||
cfg.shapelatent.decoder_type = 'PointwiseNet'
|
||||
cfg.shapelatent.loss0_weight = 1.0
|
||||
cfg.shapelatent.latent_dim = 256
|
||||
cfg.shapelatent.kl_weight = 1e-3
|
||||
cfg.shapelatent.decoder_num_points = -1
|
||||
# offset the sigma towards zero for better init, will use the log_sigma - offset value, better to be positive s.t. - offset < 0 since we'd like to push it towards 0; exp(-0.1)=0.9, exp(-0.8)=0.44, exp(-1)=0.3, exp(-10)=4e-5
|
||||
cfg.shapelatent.log_sigma_offset = 0.0
|
||||
|
||||
cfg.sde = CN()
|
||||
cfg.sde.ode_sample = 0 #1
|
||||
# train the prior or not, default is 1, only when we do voxel2pts, will freeze prior
|
||||
cfg.sde.train_dae = 1
|
||||
cfg.sde.init_t = 1.0 # start from time = 1.0
|
||||
cfg.sde.nhead = 4 # number of head in transformder: multi-head attention layer
|
||||
cfg.sde.local_prior = 'same_as_global' # architecture for local prior
|
||||
cfg.sde.drop_inactive_var = 0
|
||||
cfg.sde.learn_mixing_logit = 1 # freeze it
|
||||
cfg.sde.regularize_mlogit_margin = 0.0
|
||||
cfg.sde.share_mlogit = 0 # use same mlogit for all latent variables
|
||||
cfg.sde.hypara_mixing_logit = 0 # set as hyper-parameter and freeze it?
|
||||
cfg.sde.bound_mlogit = 0 # clamp or not
|
||||
cfg.sde.bound_mlogit_value = -5.42 # clamp the max value
|
||||
cfg.sde.regularize_mlogit = 0 # set the sum of sigmoid(mlogit) as one loss
|
||||
cfg.sde.attn_mhead = 0 # use multi-head attention in prior model
|
||||
cfg.sde.attn_mhead_local = -1 # use multi-head attention in prior model
|
||||
cfg.sde.pos_embed = 'none'
|
||||
cfg.sde.hier_prior = 0
|
||||
cfg.sde.is_continues = 0
|
||||
cfg.sde.time_emb_scales = 1.0 # -> 1k?
|
||||
cfg.sde.time_eps = 1e-2
|
||||
cfg.sde.ode_eps = 1e-5 # cut off for ode sampling
|
||||
cfg.sde.sde_type = 'vpsde' # vada
|
||||
cfg.sde.sigma2_0 = 0.0
|
||||
cfg.sde.sigma2_max = 0.99
|
||||
cfg.sde.sigma2_min = 1e-4
|
||||
cfg.sde.beta_start = 0.1 # 1e-4 * 1e3
|
||||
cfg.sde.beta_end = 20.0 # 1e-2 * 1e3
|
||||
# sampling, always iw # ll: small times; 'll_uniform' # -> ll_iw
|
||||
cfg.sde.iw_sample_p = 'll_iw'
|
||||
# drop_all_iw / drop_sigma2t_iw
|
||||
cfg.sde.iw_subvp_like_vp_sde = False
|
||||
cfg.sde.prior_model = 'models.latent_points_ada_localprior.PVCNN2Prior'
|
||||
|
||||
# -- to train diffusion in latent space -- #
|
||||
cfg.sde.update_q_ema = False
|
||||
cfg.sde.iw_sample_q = 'reweight_p_samples'
|
||||
# ll_iw / reweight_p_samples
|
||||
cfg.sde.kl_anneal_portion_vada = 0.1
|
||||
cfg.sde.kl_const_portion_vada = 0.0
|
||||
cfg.sde.kl_const_coeff_vada = 0.7
|
||||
cfg.sde.kl_balance_vada = False
|
||||
cfg.sde.grad_clip_max_norm = 0.0
|
||||
cfg.sde.cont_kl_anneal = True
|
||||
# False
|
||||
cfg.sde.mixing_logit_init = -6
|
||||
cfg.sde.weight_decay_norm_vae = 0.0 #1e-2
|
||||
cfg.sde.weight_decay_norm_dae = 0.0 #1e-2
|
||||
# -> 0, for sn calculator
|
||||
cfg.sde.train_vae = True
|
||||
cfg.sde.jac_reg_coeff = 0
|
||||
cfg.sde.jac_reg_freq = 1
|
||||
cfg.sde.kin_reg_coeff = 0
|
||||
cfg.sde.learning_rate_mlogit = -1.0
|
||||
cfg.sde.learning_rate_dae_local = 3e-4
|
||||
cfg.sde.learning_rate_min_dae_local = 3e-4
|
||||
cfg.sde.learning_rate_dae = 3e-4
|
||||
cfg.sde.learning_rate_min_dae = 3e-4
|
||||
cfg.sde.learning_rate_min_vae = 1e-5
|
||||
cfg.sde.learning_rate_vae = 1e-4
|
||||
cfg.sde.epochs = 800
|
||||
cfg.sde.warmup_epochs = 20
|
||||
cfg.sde.weight_decay = 3e-4
|
||||
cfg.sde.use_adamax = False
|
||||
cfg.sde.use_adam = True # False
|
||||
cfg.sde.mixed_prediction = False # True
|
||||
cfg.sde.vae_checkpoint = ''
|
||||
cfg.sde.dae_checkpoint = ''
|
||||
# will be used to multiply with the t value, if ode solver, use 1k, if discrete solver, use 1.0
|
||||
cfg.sde.embedding_scale = 1.0 # 1000.0
|
||||
cfg.sde.embedding_type = 'positional'
|
||||
cfg.sde.train_ode_solver_tol = 1e-5
|
||||
cfg.sde.num_scales_dae = 2
|
||||
cfg.sde.autocast_train = False
|
||||
cfg.sde.diffusion_steps = 1000
|
||||
cfg.sde.embedding_dim = 128
|
||||
cfg.sde.num_channels_dae = 256
|
||||
cfg.sde.num_cell_per_scale_dae = 8
|
||||
cfg.sde.num_cell_per_scale_dae_local = 0
|
||||
cfg.sde.dropout = 0.2
|
||||
cfg.sde.num_preprocess_blocks = 2
|
||||
cfg.sde.num_latent_scales = 1
|
||||
cfg.sde.fir = False
|
||||
cfg.sde.progressive = 'none'
|
||||
cfg.sde.progressive_input = 'none'
|
||||
cfg.sde.progressive_combine = 'sum'
|
||||
cfg.sde.dataset = 'shape'
|
||||
cfg.sde.denoising_stddevs = 'beta'
|
||||
cfg.sde.ema_decay = 0.9999
|
||||
# cfg.sde.is_train_vae=True
|
||||
cfg.register_deprecated_key("sde.is_train_vae")
|
||||
cfg.sde.kl_max_coeff_vada = 1.0
|
||||
# conditional prior input
|
||||
cfg.sde.condition_add = 1
|
||||
cfg.sde.condition_cat = 0
|
||||
cfg.sde.global_prior_ckpt = '' # checkpoint for global prior component
|
||||
cfg.sde.pool_feat_cat = 0 # the local prior aggregate the feat as extra input channels
|
||||
|
||||
# hyperparameter of ddim sampling
|
||||
cfg.sde.ddim_skip_type = 'uniform'
|
||||
cfg.sde.ddim_kappa = 1.0 # 1.0: fully ddpm sampling; 0: ode style sampling
|
||||
|
||||
cfg.ddpm = CN()
|
||||
cfg.ddpm.use_p2_weight = 0
|
||||
cfg.ddpm.p2_k = 1.0
|
||||
cfg.ddpm.p2_gamma = 1.0
|
||||
cfg.ddpm.use_new_timeemb = 0
|
||||
cfg.ddpm.input_dim = 3
|
||||
cfg.ddpm.dropout = 0.1
|
||||
cfg.ddpm.num_layers_classifier = 3
|
||||
cfg.ddpm.use_bn = True
|
||||
cfg.ddpm.add_point_feat = True
|
||||
cfg.ddpm.use_gn = False
|
||||
cfg.ddpm.time_dim = 64
|
||||
cfg.ddpm.ema = 1
|
||||
cfg.ddpm.with_se = 0
|
||||
cfg.ddpm.use_global_attn = 0
|
||||
cfg.ddpm.num_steps = 1000
|
||||
cfg.ddpm.beta_1 = 1e-4
|
||||
cfg.ddpm.beta_T = 2e-2
|
||||
# ['linear', 'customer'] 'customer' for airplane in PVD
|
||||
cfg.ddpm.sched_mode = 'linear'
|
||||
cfg.ddpm.model_var_type = 'fixedlarge'
|
||||
# define architecture:
|
||||
cfg.register_deprecated_key("ddpm.pointnet_plus")
|
||||
cfg.register_deprecated_key("ddpm.pointnet_pp")
|
||||
cfg.register_deprecated_key("ddpm.pointnet_luo")
|
||||
# end define architecture
|
||||
#cfg.ddpm.use_pvc = 1
|
||||
cfg.register_deprecated_key("ddpm.use_pvc")
|
||||
cfg.ddpm.clip_denoised = 0
|
||||
cfg.ddpm.model_mean_type = 'eps'
|
||||
cfg.ddpm.loss_type = 'mse'
|
||||
cfg.ddpm.loss_type_0 = ''
|
||||
cfg.ddpm.loss_weight_emd = 0.02
|
||||
cfg.ddpm.loss_weight_cdnorm = 1.0
|
||||
cfg.ddpm.attn = [0, 1, 0, 0]
|
||||
cfg.ddpm.ncenter = [1024, 256, 64, 16]
|
||||
|
||||
#cfg.ddpm.pvc = CN()
|
||||
#cfg.ddpm.pvc.use_small_model = 0
|
||||
#cfg.ddpm.pvc.mlp_after_pvc = 0
|
||||
cfg.register_deprecated_key("ddpm.pvc")
|
||||
cfg.register_deprecated_key("ddpm.pvc.use_small_model")
|
||||
cfg.register_deprecated_key("ddpm.pvc.mlp_after_pvc")
|
||||
|
||||
cfg.ddpm.ddim_step = 200
|
||||
|
||||
cfg.data = CN()
|
||||
cfg.data.nclass = 55
|
||||
cfg.data.cond_on_cat = 0
|
||||
cfg.data.cond_on_voxel = 0
|
||||
cfg.data.eval_test_split = 0 # eval loader will be using test split
|
||||
cfg.data.voxel_size = 0.1 # size of voxel for voxel_datasets.py
|
||||
cfg.data.noise_std = 0.1 # std for the noise added to the input data
|
||||
cfg.data.noise_type = 'normal' # std for the noise added to the input data
|
||||
cfg.data.noise_std_min = -1.0 # for range of noise std
|
||||
cfg.data.clip_forge_enable = 0
|
||||
cfg.data.clip_model = 'ViT-B/32'
|
||||
cfg.data.type = "datasets.pointflow_datasets"
|
||||
# datasets/neuralspline_datasets datasets/shape_curvature
|
||||
cfg.data.dataset_type = "shapenet15k"
|
||||
cfg.data.num_workers = 12 # 8
|
||||
cfg.data.train_drop_last = 1 # drop_last for train data loader
|
||||
cfg.data.cates = 'chair' # data category
|
||||
cfg.data.tr_max_sample_points = 2048
|
||||
cfg.data.te_max_sample_points = 2048
|
||||
cfg.data.data_dir = "data/ShapeNetCore.v2.PC15k" # depreciated
|
||||
cfg.data.batch_size = 12
|
||||
cfg.data.batch_size_test = 10
|
||||
cfg.data.dataset_scale = 1
|
||||
# -- the following option in terms of normalization should turn into string -- #
|
||||
cfg.data.normalize_per_shape = False
|
||||
cfg.data.normalize_shape_box = False
|
||||
cfg.data.normalize_global = False
|
||||
cfg.data.normalize_std_per_axis = False
|
||||
cfg.data.normalize_range = False # not used
|
||||
cfg.data.recenter_per_shape = True
|
||||
# -- for the normal prediction model, used in folder_datasets
|
||||
cfg.register_deprecated_key('data.load_point_stat')
|
||||
cfg.register_deprecated_key('data.is_load_pointflow2NS')
|
||||
cfg.register_deprecated_key('data.data_path')
|
||||
|
||||
#
|
||||
cfg.data.sample_with_replacement = 1
|
||||
# fixed the data.tr_max_sample_points $np data.te_max_sample_points $np2048 points of the first 15k points
|
||||
cfg.data.random_subsample = 1
|
||||
# the data dim, used in dataset worker, if -1, it will be the same as ddpm.input_dim
|
||||
cfg.data.input_dim = -1
|
||||
cfg.data.is_encode_whole_dataset_trainer = 0
|
||||
cfg.register_deprecated_key('data.augment')
|
||||
cfg.register_deprecated_key('data.aug_translate')
|
||||
cfg.register_deprecated_key('data.aug_scale')
|
||||
cfg.register_deprecated_key('data.sub_train_set')
|
||||
|
||||
cfg.test_size = 660
|
||||
|
||||
cfg.viz = CN()
|
||||
cfg.viz.log_freq = 10
|
||||
cfg.viz.viz_freq = 400
|
||||
cfg.viz.save_freq = 200
|
||||
cfg.viz.val_freq = -1
|
||||
cfg.viz.viz_order = [2, 0, 1]
|
||||
cfg.viz.vis_sample_ddim_step = 0
|
||||
|
||||
cfg.trainer = CN()
|
||||
# when loss 1 is weighted, also weight the kl terms
|
||||
cfg.trainer.apply_loss_weight_1_kl = 0
|
||||
cfg.trainer.kl_free = [0, 0] # the value for the threshold
|
||||
# not back ward kl loss if KL value is smaller than the threshold
|
||||
cfg.trainer.use_kl_free = 0
|
||||
cfg.trainer.type = "trainers.ddpm_trainer" # it means dist trainer
|
||||
cfg.trainer.epochs = 10000
|
||||
cfg.trainer.warmup_epochs = 0
|
||||
cfg.trainer.seed = 1
|
||||
cfg.trainer.use_grad_scalar = 0
|
||||
cfg.trainer.opt = CN()
|
||||
cfg.trainer.opt.type = 'adam'
|
||||
cfg.trainer.opt.lr = 1e-4 # use bs*1e-5/8
|
||||
cfg.trainer.opt.lr_min = 1e-4 # use bs*1e-5/8
|
||||
# lr start to anneal after ratio of epochs; used in cosine and lambda lr scheduler
|
||||
cfg.trainer.opt.start_ratio = 0.6
|
||||
cfg.trainer.opt.beta1 = 0.9
|
||||
cfg.trainer.opt.beta2 = 0.999
|
||||
cfg.trainer.opt.momentum = 0.9 # for SGD
|
||||
cfg.trainer.opt.weight_decay = 0.
|
||||
cfg.trainer.opt.ema_decay = 0.9999
|
||||
cfg.trainer.opt.grad_clip = -1.
|
||||
cfg.trainer.opt.scheduler = ''
|
||||
cfg.trainer.opt.step_decay = 0.998
|
||||
cfg.trainer.opt.vae_lr_warmup_epochs = 0
|
||||
cfg.trainer.anneal_kl = 0
|
||||
cfg.trainer.kl_balance = 0
|
||||
cfg.trainer.rec_balance = 0
|
||||
cfg.trainer.loss1_weight_anneal_v = 'quad'
|
||||
cfg.trainer.kl_ratio = [1.0, 1.0]
|
||||
cfg.trainer.kl_ratio_apply = 0 # apply the fixed kl ratio in the kl_ratio list
|
||||
# using spectral norm regularization on vae training or not (used in hvae_trainer)
|
||||
cfg.trainer.sn_reg_vae = 0
|
||||
cfg.trainer.sn_reg_vae_weight = 0.0 # loss weight for the sn regulatrization
|
||||
|
||||
# [start] set in runtime
|
||||
cfg.log_name = ''
|
||||
cfg.save_dir = ''
|
||||
cfg.log_dir = ''
|
||||
cfg.comet_key = ''
|
||||
# [end]
|
||||
|
||||
cfg.voxel2pts = CN()
|
||||
cfg.voxel2pts.init_weight = ''
|
||||
cfg.voxel2pts.diffusion_steps = [0]
|
||||
|
||||
cfg.dpm = CN()
|
||||
cfg.dpm.train_encoder_only = 0
|
||||
cfg.num_ref = 0 # manully set the number of reference
|
||||
cfg.eval_ddim_step = 0 # ddim sampling for the model evaluation
|
||||
cfg.model_config = '' # used for model control, without ading new flag
|
||||
|
||||
## --- depreciated --- #
|
||||
cfg.register_deprecated_key('cls') # CN()
|
||||
cfg.register_deprecated_key('cls.classifier_type') # 'models.classifier.OneLayer'
|
||||
cfg.register_deprecated_key('cls.train_on_eps') # 1
|
||||
cfg.register_deprecated_key('cond_prior') # CN()
|
||||
cfg.register_deprecated_key('cond_prior.grid_emb_resolution') # 32
|
||||
cfg.register_deprecated_key('cond_prior.emb_dim') # 64
|
||||
cfg.register_deprecated_key('cond_prior.use_voxel_feat') # 1
|
||||
cfg.register_deprecated_key('cond_encoder_prior') # 'models.shapelatent_modules.VoxelGridEncoder'
|
||||
cfg.register_deprecated_key('cond_prior.pvcconv_concat_3d_feat_input') # 0
|
||||
cfg.register_deprecated_key('generate_mode_global') # 'interpolate'
|
||||
cfg.register_deprecated_key('generate_mode_local') # 'freeze'
|
||||
cfg.register_deprecated_key('normals') # CN()
|
||||
cfg.register_deprecated_key('normals.model_type') # ''
|
||||
cfg.register_deprecated_key('save_sample_seq_and_quit') # 0
|
||||
cfg.register_deprecated_key('lns_loss_weight') # 1.0
|
||||
cfg.register_deprecated_key('normal_pred_checkpoint') # ''
|
||||
cfg.register_deprecated_key('lns') # CN()
|
||||
cfg.register_deprecated_key('lns.override_config') # ''
|
||||
cfg.register_deprecated_key('lns.wandb_checkpoint') # 'nvidia-toronto/generative_chairs/3m3gc6sz/checkpoint-171.pth'
|
||||
cfg.register_deprecated_key('lns.num_input_points') # 1000
|
||||
cfg.register_deprecated_key('lns.num_simulate') # 20
|
||||
cfg.register_deprecated_key('lns.split_simulate') # 'train'
|
||||
# use mesh-trainer or not
|
||||
cfg.register_deprecated_key('with_lns') # 0
|
||||
|
||||
cfg.register_deprecated_key('normal_predictor_yaml') # ''
|
||||
|
||||
cfg.register_deprecated_key('pointtransformer') # CN()
|
||||
# number of attention layer in each block
|
||||
cfg.register_deprecated_key('pointtransformer.blocks') # [2, 3, 4, 6, 3]
|
||||
cfg.register_deprecated_key('shapelatent.refiner_bp') # 1 # bp gradient to the local-decoder or not
|
||||
cfg.register_deprecated_key('shapelatent.loss_weight_refiner') # 1.0 # weighted loss for the refiner
|
||||
cfg.register_deprecated_key('shapelatent.refiner_type') # 'models.pvcnn2.PVCNN2BaseAPI' # mode for the refiner
|
||||
|
||||
cfg.register_deprecated_key('shapelatent.encoder_weight_std') # 0.1
|
||||
cfg.register_deprecated_key('shapelatent.encoder_weight_norm') # 0
|
||||
cfg.register_deprecated_key('shapelatent.encoder_weight_uniform') # 1
|
||||
cfg.register_deprecated_key('shapelatent.key_point_gen') # 'mlps'
|
||||
cfg.register_deprecated_key('shapelatent.add_sub_loss') # 1 # not used
|
||||
cfg.register_deprecated_key('shapelatent.local_decoder_type') # ''
|
||||
cfg.register_deprecated_key('shapelatent.local_decoder_type_1') # ''
|
||||
cfg.register_deprecated_key('shapelatent.local_encoder_ball_radius') # 0.8
|
||||
cfg.register_deprecated_key('shapelatent.local_encoder_ap_ball_radius') # 1.0
|
||||
cfg.register_deprecated_key('shapelatent.local_encoder_type') # ''
|
||||
cfg.register_deprecated_key('shapelatent.local_encoder_type_1') # ''
|
||||
cfg.register_deprecated_key('shapelatent.local_loss_weight_max') # 50
|
||||
cfg.register_deprecated_key('shapelatent.num_neighbors') # 0
|
||||
cfg.register_deprecated_key('shapelatent.extra_centers') # []
|
||||
# for latent model is flow
|
||||
cfg.register_deprecated_key('shapelatent.latent_flow_depth') # 14
|
||||
cfg.register_deprecated_key('shapelatent.latent_flow_hidden_dim') # 256
|
||||
cfg.register_deprecated_key('shapelatent.bp_to_l0') # True
|
||||
cfg.register_deprecated_key('shapelatent.global_only_epochs') # 0
|
||||
cfg.register_deprecated_key('shapelatent.center_local_points') # 1
|
||||
cfg.register_deprecated_key('shapelatent.hvae') # CN()
|
||||
# alternatively way to compute the local loss
|
||||
cfg.register_deprecated_key('shapelatent.hvae.loss_wrt_ori') # 0
|
||||
# add voxel feature to the latent space; the decoder require pvc conv or query
|
||||
cfg.register_deprecated_key('shapelatent.add_voxel2z_global') # 0
|
||||
# reuse the encoder to get local latent
|
||||
cfg.register_deprecated_key('shapelatent.query_output_local_from_enc') # 0
|
||||
# check models/shapelatent_modules where the feature will be saved as a dict
|
||||
cfg.register_deprecated_key('shapelatent.query_local_feat_layer') # 'inter_voxelfeat_0'
|
||||
# need to check the sa_blocks of the global encoder
|
||||
cfg.register_deprecated_key('shapelatent.query_local_feat_dim') # 32
|
||||
# reuse the encoder to get local latent
|
||||
cfg.register_deprecated_key('shapelatent.query_center_emd_from_enc') # 0 # reuse the encoder for center emd
|
||||
cfg.register_deprecated_key('shapelatent.prog_dec_gf') # 8 # grow_factor in VaniDecoderProg
|
||||
cfg.register_deprecated_key('shapelatent.prog_dec_gf_list') # [0, 0] # grow_factor in VaniDecoderProg
|
||||
cfg.register_deprecated_key('shapelatent.prog_dec_ne') # 2 # num_expand in VaniDecoderProg
|
||||
# increase number hirach, used by hvaemul model
|
||||
cfg.register_deprecated_key('shapelatent.num_neighbors_per_level') # [64] # number of neighbors for each level
|
||||
cfg.register_deprecated_key('shapelatent.num_level') # 1 # number of hierarchi latent space (local)
|
||||
cfg.register_deprecated_key('shapelatent.x0_target_fps') # 0 # let the target of global output as the
|
||||
cfg.register_deprecated_key('shapelatent.downsample_input_ratio') # 1.0
|
||||
# whether taking other tensor as input to local-encoder of not
|
||||
cfg.register_deprecated_key('shapelatent.local_enc_input') # 'sim'
|
||||
# local encoder take z0 as input at which location
|
||||
cfg.register_deprecated_key('shapelatent.local_encoder_condition_z0') # ''
|
||||
# output the absolution coordinates or the offset w.r.t centers
|
||||
cfg.register_deprecated_key('shapelatent.local_decoder_output_offset') # 0
|
||||
# feed coords of keypoints to the local prior model
|
||||
cfg.register_deprecated_key('shapelatent.local_prior_need_coords') # 0
|
||||
|
||||
# add the time embedding tensor to each encoder layer instead of add to first layer only
|
||||
cfg.register_deprecated_key('sde.transformer_temb2interlayer') # 0
|
||||
# normalization used in transformer encoder;
|
||||
cfg.register_deprecated_key('sde.transformer_norm_type') # 'layer_norm'
|
||||
cfg.register_deprecated_key('data.has_normal') # 0 # for datasets/pointflow_rgb.py only
|
||||
cfg.register_deprecated_key('data.has_color') # 0 # for datasets/pointflow_rgb.py only
|
||||
cfg.register_deprecated_key('data.cls_data_ratio') # 1.0 # ratio of the training data
|
||||
cfg.register_deprecated_key('data.sample_curvature') # 0 # only for datasets/shape_curvature
|
||||
cfg.register_deprecated_key('data.ratio_c') # 1.0 # only for datasets/shape_curvature
|
45
demo.py
Normal file
45
demo.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
|
||||
"""
|
||||
require diffusers-0.11.1
|
||||
"""
|
||||
import os
|
||||
import clip
|
||||
import torch
|
||||
from PIL import Image
|
||||
from default_config import cfg as config
|
||||
from models.lion import LION
|
||||
from utils.vis_helper import plot_points
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
model_path = './lion_ckpt/text2shape/chair/checkpoints/model.pt'
|
||||
model_config = './lion_ckpt/text2shape/chair/cfg.yml'
|
||||
|
||||
config.merge_from_file(model_config)
|
||||
lion = LION(config)
|
||||
lion.load_model(model_path)
|
||||
|
||||
if config.clipforge.enable:
|
||||
input_t = ["a swivel chair, five wheels"]
|
||||
device_str = 'cuda'
|
||||
clip_model, clip_preprocess = clip.load(
|
||||
config.clipforge.clip_model, device=device_str)
|
||||
text = clip.tokenize(input_t).to(device_str)
|
||||
clip_feat = []
|
||||
clip_feat.append(clip_model.encode_text(text).float())
|
||||
clip_feat = torch.cat(clip_feat, dim=0)
|
||||
print('clip_feat', clip_feat.shape)
|
||||
else:
|
||||
clip_feat = None
|
||||
output = lion.sample(1 if clip_feat is None else clip_feat.shape[0], clip_feat=clip_feat)
|
||||
pts = output['points']
|
||||
img_name = "/tmp/tmp.png"
|
||||
plot_points(pts, output_name=img_name)
|
||||
img = Image.open(img_name)
|
||||
img.show()
|
311
env.yaml
Normal file
311
env.yaml
Normal file
|
@ -0,0 +1,311 @@
|
|||
name: lion_env
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- anaconda
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=4.5=1_gnu
|
||||
- argon2-cffi=20.1.0=py38h27cfd23_1
|
||||
- async_generator=1.10=pyhd3eb1b0_0
|
||||
- attrs=21.4.0=pyhd3eb1b0_0
|
||||
- backcall=0.2.0=pyhd3eb1b0_0
|
||||
- blas=1.0=mkl
|
||||
- bleach=4.1.0=pyhd3eb1b0_0
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2020.10.14=0
|
||||
- certifi=2020.6.20=py38_0
|
||||
- cffi=1.15.0=py38hd667e15_1
|
||||
- cmake=3.18.2=ha30ef3c_0
|
||||
- cudatoolkit=11.1.74=h6bb024c_0
|
||||
- debugpy=1.5.1=py38h295c915_0
|
||||
- decorator=5.1.1=pyhd3eb1b0_0
|
||||
- defusedxml=0.7.1=pyhd3eb1b0_0
|
||||
- entrypoints=0.3=py38_0
|
||||
- expat=2.2.10=he6710b0_2
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
- giflib=5.2.1=h7b6447c_0
|
||||
- gmp=6.2.1=h2531618_2
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- importlib_metadata=4.8.2=hd3eb1b0_0
|
||||
- intel-openmp=2021.4.0=h06a4308_3561
|
||||
- ipykernel=6.4.1=py38h06a4308_1
|
||||
- ipython=7.31.1=py38h06a4308_0
|
||||
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
||||
- ipywidgets=7.6.5=pyhd3eb1b0_1
|
||||
- jedi=0.18.1=py38h06a4308_0
|
||||
- jpeg=9d=h7f8727e_0
|
||||
- jupyter_client=7.1.2=pyhd3eb1b0_0
|
||||
- jupyter_core=4.9.1=py38h06a4308_0
|
||||
- jupyterlab_pygments=0.1.2=py_0
|
||||
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
|
||||
- krb5=1.18.2=h173b8e3_0
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||
- libcurl=7.71.1=h20c2e04_1
|
||||
- libedit=3.1.20191231=h14c3975_1
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=9.3.0=h5101ec6_17
|
||||
- libgomp=9.3.0=h5101ec6_17
|
||||
- libiconv=1.15=h63c8f33_5
|
||||
- libidn2=2.3.2=h7f8727e_0
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libsodium=1.0.18=h7b6447c_0
|
||||
- libssh2=1.9.0=h1ba5d50_1
|
||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h85742a9_0
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libwebp=1.2.0=h89dd481_0
|
||||
- libwebp-base=1.2.0=h27cfd23_0
|
||||
- lz4-c=1.9.3=h295c915_1
|
||||
- markupsafe=2.0.1=py38h27cfd23_0
|
||||
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
|
||||
- mistune=0.8.4=py38h7b6447c_1000
|
||||
- mkl=2021.4.0=h06a4308_640
|
||||
- mkl-service=2.4.0=py38h7f8727e_0
|
||||
- mkl_fft=1.3.1=py38hd3c417c_0
|
||||
- mkl_random=1.2.2=py38h51133e4_0
|
||||
- nbclient=0.5.3=pyhd3eb1b0_0
|
||||
- nbconvert=6.3.0=py38h06a4308_0
|
||||
- ncurses=6.3=h7f8727e_2
|
||||
- nest-asyncio=1.5.1=pyhd3eb1b0_0
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- notebook=6.4.6=py38h06a4308_0
|
||||
- numpy=1.21.2=py38h20f2e39_0
|
||||
- numpy-base=1.21.2=py38h79a1101_0
|
||||
- olefile=0.46=pyhd3eb1b0_0
|
||||
- openh264=2.1.1=h4ff587b_0
|
||||
- openssl=1.1.1m=h7f8727e_0
|
||||
- packaging=21.3=pyhd3eb1b0_0
|
||||
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
||||
- parso=0.8.3=pyhd3eb1b0_0
|
||||
- pexpect=4.8.0=pyhd3eb1b0_3
|
||||
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
||||
- pillow=8.4.0=py38h5aabda8_0
|
||||
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
|
||||
- prometheus_client=0.13.1=pyhd3eb1b0_0
|
||||
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
||||
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
||||
- pycparser=2.21=pyhd3eb1b0_0
|
||||
- pygments=2.11.2=pyhd3eb1b0_0
|
||||
- python=3.8.12=h12debd9_0
|
||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||
- python-fastjsonschema=2.16.1=pyhd8ed1ab_0
|
||||
- python_abi=3.8=2_cp38
|
||||
- pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pyzmq=22.3.0=py38h295c915_2
|
||||
- readline=8.1.2=h7f8727e_1
|
||||
- rhash=1.4.0=h1ba5d50_0
|
||||
- send2trash=1.8.0=pyhd3eb1b0_1
|
||||
- six=1.16.0=pyhd3eb1b0_0
|
||||
- sqlite=3.37.2=hc218d9a_0
|
||||
- terminado=0.9.4=py38h06a4308_0
|
||||
- testpath=0.5.0=pyhd3eb1b0_0
|
||||
- tk=8.6.11=h1ccaba5_0
|
||||
- torchaudio=0.10.2=py38_cu111
|
||||
- torchvision=0.11.3=py38_cu111
|
||||
- tornado=6.1=py38h27cfd23_0
|
||||
- traitlets=5.1.1=pyhd3eb1b0_0
|
||||
- wcwidth=0.2.5=pyhd3eb1b0_0
|
||||
- webencodings=0.5.1=py38_1
|
||||
- wheel=0.37.1=pyhd3eb1b0_0
|
||||
- widgetsnbextension=3.5.1=py38_0
|
||||
- xz=5.2.5=h7b6447c_0
|
||||
- zeromq=4.3.4=h2531618_0
|
||||
- zipp=3.7.0=pyhd3eb1b0_0
|
||||
- zlib=1.2.11=h7f8727e_4
|
||||
- zstd=1.4.5=h9ceee32_0
|
||||
- pip:
|
||||
- about-time==3.1.1
|
||||
- absl-py==1.0.0
|
||||
- addict==2.4.0
|
||||
- aiohttp==3.8.1
|
||||
- aiosignal==1.2.0
|
||||
- alive-progress==2.2.0
|
||||
- antlr4-python3-runtime==4.9.3
|
||||
- anyio==3.5.0
|
||||
- astunparse==1.6.3
|
||||
- async-timeout==4.0.2
|
||||
- babel==2.9.1
|
||||
- cachetools==5.0.0
|
||||
- calmsize==0.1.3
|
||||
- ccimport==0.3.7
|
||||
- cftime==1.6.0
|
||||
- charset-normalizer==2.0.11
|
||||
- click==8.0.3
|
||||
- colorama==0.4.4
|
||||
- comet-ml==3.31.21
|
||||
- commonmark==0.9.1
|
||||
- configobj==5.0.6
|
||||
- crc32c==2.2.post0
|
||||
- cumm-cu111==0.2.8
|
||||
- cupy-cuda111==10.2.0
|
||||
- cycler==0.11.0
|
||||
- cython==0.29.20
|
||||
- dataclasses==0.6
|
||||
- deepspeed==0.6.5
|
||||
- deprecation==2.1.0
|
||||
- diffusers==0.11.1
|
||||
- docker-pycreds==0.4.0
|
||||
- drjit==0.2.1
|
||||
- dulwich==0.20.32
|
||||
- easydict==1.9
|
||||
- einops==0.4.0
|
||||
- everett==3.0.0
|
||||
- fastrlock==0.8
|
||||
- filelock==3.9.0
|
||||
- fire==0.4.0
|
||||
- flatbuffers==2.0
|
||||
- flatten-dict==0.4.2
|
||||
- fonttools==4.29.1
|
||||
- freetype-py==2.3.0
|
||||
- frozenlist==1.3.0
|
||||
- fsspec==2022.2.0
|
||||
- ftfy==6.1.1
|
||||
- future==0.18.2
|
||||
- fvcore==0.1.5.post20220512
|
||||
- gast==0.5.3
|
||||
- gitdb==4.0.9
|
||||
- gitpython==3.1.26
|
||||
- google-auth==2.6.0
|
||||
- google-auth-oauthlib==0.4.6
|
||||
- google-pasta==0.2.0
|
||||
- grapheme==0.6.0
|
||||
- grpcio==1.43.0
|
||||
- h5py==3.6.0
|
||||
- hjson==3.0.2
|
||||
- huggingface-hub==0.11.1
|
||||
- idna==3.3
|
||||
- imageio==2.15.0
|
||||
- imageio-ffmpeg==0.4.5
|
||||
- importlib-metadata==4.10.1
|
||||
- importlib-resources==5.4.0
|
||||
- iopath==0.1.10
|
||||
- jinja2==3.1.1
|
||||
- joblib==1.1.0
|
||||
- json5==0.9.6
|
||||
- jsonschema==4.4.0
|
||||
- jupyter-packaging==0.12.0
|
||||
- jupyter-server==1.15.6
|
||||
- jupyterlab==3.3.2
|
||||
- jupyterlab-server==2.11.2
|
||||
- keras==2.8.0
|
||||
- keras-preprocessing==1.1.2
|
||||
- kiwisolver==1.3.2
|
||||
- kornia==0.6.6
|
||||
- lark==1.1.2
|
||||
- libclang==14.0.1
|
||||
- llvmlite==0.39.0
|
||||
- loguru==0.6.0
|
||||
- markdown==3.3.6
|
||||
- matplotlib==3.5.1
|
||||
- matplotlib2tikz==0.7.6
|
||||
- meshio==5.3.4
|
||||
- mitsuba==3.0.1
|
||||
- mrcfile==1.3.0
|
||||
- multidict==6.0.2
|
||||
- multipledispatch==0.6.0
|
||||
- mypy-extensions==0.4.3
|
||||
- nbclassic==0.3.7
|
||||
- nbformat==5.2.0
|
||||
- nestargs==0.5.0
|
||||
- netcdf4==1.5.8
|
||||
- networkx==2.6.3
|
||||
- ninja==1.10.2.3
|
||||
- notebook-shim==0.1.0
|
||||
- numba==0.56.0
|
||||
- nvidia-ml-py3==7.352.0
|
||||
- oauthlib==3.2.0
|
||||
- omegaconf==2.2.2
|
||||
- open3d==0.15.2
|
||||
- opencv-python==4.5.5.64
|
||||
- openexr==1.3.7
|
||||
- opt-einsum==3.3.0
|
||||
- pandas==1.4.0
|
||||
- pathtools==0.1.2
|
||||
- pccm==0.3.4
|
||||
- pip==22.3.1
|
||||
- plyfile==0.7.4
|
||||
- portalocker==2.5.1
|
||||
- progressbar2==4.0.0
|
||||
- promise==2.3
|
||||
- protobuf==3.19.4
|
||||
- psutil==5.9.0
|
||||
- py-cpuinfo==8.0.0
|
||||
- pyasn1==0.4.8
|
||||
- pyasn1-modules==0.2.8
|
||||
- pybind11==2.10.0
|
||||
- pydeprecate==0.3.1
|
||||
- pyglet==1.5.23
|
||||
- pykeops==1.5
|
||||
- pymcubes==0.1.2
|
||||
- pyopengl==3.1.0
|
||||
- pyparsing==3.0.7
|
||||
- pyquaternion==0.9.9
|
||||
- pyrr==0.10.3
|
||||
- pyrsistent==0.18.1
|
||||
- python-swiftclient==4.0.0
|
||||
- python-utils==3.3.3
|
||||
- pytorch-lightning==1.5.1
|
||||
- pytorch3d==0.3.0
|
||||
- pytz==2021.3
|
||||
- pywavelets==1.2.0
|
||||
- pyyaml==6.0
|
||||
- regex==2022.3.15
|
||||
- requests==2.27.1
|
||||
- requests-oauthlib==1.3.1
|
||||
- requests-toolbelt==0.9.1
|
||||
- rich==12.3.0
|
||||
- rsa==4.8
|
||||
- ruamel-yaml==0.17.20
|
||||
- ruamel-yaml-clib==0.2.6
|
||||
- scikit-image==0.19.1
|
||||
- scikit-learn==1.0.2
|
||||
- scipy==1.8.0
|
||||
- seaborn==0.11.2
|
||||
- semantic-version==2.9.0
|
||||
- sentry-sdk==1.5.4
|
||||
- sharedarray==3.2.1
|
||||
- shortuuid==1.0.8
|
||||
- simple-parsing==0.0.18
|
||||
- simplejson==3.18.0
|
||||
- sklearn==0.0
|
||||
- smmap==5.0.0
|
||||
- sniffio==1.2.0
|
||||
- tabulate==0.8.9
|
||||
- tensorboard==2.8.0
|
||||
- tensorboard-data-server==0.6.1
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
- tensorboardx==2.4.1
|
||||
- tensorflow-gpu==2.8.0
|
||||
- tensorflow-io-gcs-filesystem==0.25.0
|
||||
- termcolor==1.1.0
|
||||
- tf-estimator-nightly==2.8.0.dev2021122109
|
||||
- tflearn==0.5.0
|
||||
- tfrecord==1.14.1
|
||||
- threadpoolctl==3.1.0
|
||||
- tifffile==2022.2.2
|
||||
- tikzplotlib==0.10.1
|
||||
- tomlkit==0.10.0
|
||||
- torchmetrics==0.7.2
|
||||
- tqdm==4.62.3
|
||||
- trimesh==3.10.1
|
||||
- typing-extensions==4.2.0
|
||||
- typing-inspect==0.7.1
|
||||
- urllib3==1.26.8
|
||||
- wandb==0.12.10
|
||||
- webcolors==1.11.1
|
||||
- websocket-client==1.2.3
|
||||
- werkzeug==2.0.3
|
||||
- wrapt==1.13.3
|
||||
- wurlitzer==3.0.2
|
||||
- yacs==0.1.8
|
||||
- yarl==1.7.2
|
||||
- yaspin==2.1.0
|
67
models/adagn.py
Normal file
67
models/adagn.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
"""
|
||||
adaptive group norm
|
||||
"""
|
||||
from loguru import logger
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
from utils.checker import *
|
||||
from .dense import dense
|
||||
import os
|
||||
|
||||
class AdaGN(nn.Module):
|
||||
'''
|
||||
adaptive group normalization
|
||||
'''
|
||||
def __init__(self, ndim, cfg, n_channel):
|
||||
"""
|
||||
ndim: dim of the input features
|
||||
n_channel: number of channels of the inputs
|
||||
ndim_style: channel of the style features
|
||||
"""
|
||||
super().__init__()
|
||||
style_dim = cfg.latent_pts.style_dim
|
||||
init_scale = cfg.latent_pts.ada_mlp_init_scale
|
||||
self.ndim = ndim
|
||||
self.n_channel = n_channel
|
||||
self.style_dim = style_dim
|
||||
self.out_dim = n_channel * 2
|
||||
self.norm = nn.GroupNorm(8, n_channel)
|
||||
in_channel = n_channel
|
||||
self.emd = dense(style_dim, n_channel*2, init_scale=init_scale)
|
||||
self.emd.bias.data[:in_channel] = 1
|
||||
self.emd.bias.data[in_channel:] = 0
|
||||
|
||||
def __repr__(self):
|
||||
return f"AdaGN(GN(8, {self.n_channel}), Linear({self.style_dim}, {self.out_dim}))"
|
||||
|
||||
def forward(self, image, style):
|
||||
# style: B,D
|
||||
# image: B,D,N,1
|
||||
CHECK2D(style)
|
||||
style = self.emd(style)
|
||||
if self.ndim == 3: #B,D,V,V,V
|
||||
CHECK5D(image)
|
||||
style = style.view(style.shape[0], -1, 1, 1, 1) # 5D
|
||||
elif self.ndim == 2: # B,D,N,1
|
||||
CHECK4D(image)
|
||||
style = style.view(style.shape[0], -1, 1, 1) # 4D
|
||||
elif self.ndim == 1: # B,D,N
|
||||
CHECK3D(image)
|
||||
style = style.view(style.shape[0], -1, 1) # 4D
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
factor, bias = style.chunk(2, 1)
|
||||
result = self.norm(image)
|
||||
result = result * factor + bias
|
||||
return result
|
||||
|
||||
|
80
models/dense.py
Normal file
80
models/dense.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
|
||||
""" copied and modified from https://github.com/CW-Huang/sdeflow-light/blob/524650bc5ad69522b3e0905672deef0650374512/lib/models/unet.py """
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
"""
|
||||
copied and modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L337
|
||||
"""
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out', 'fan_avg']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_uniform_(tensor, gain=1., mode='fan_in'):
|
||||
r"""Fills the input `Tensor` with values according to the method
|
||||
described in `Delving deep into rectifiers: Surpassing human-level
|
||||
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||
uniform distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
|
||||
.. math::
|
||||
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
|
||||
Also known as He initialization.
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
gain: multiplier to the dispersion
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.kaiming_uniform_(w, mode='fan_in')
|
||||
"""
|
||||
fan = _calculate_correct_fan(tensor, mode)
|
||||
# gain = calculate_gain(nonlinearity, a)
|
||||
var = gain / max(1., fan)
|
||||
bound = math.sqrt(3.0 * var) # Calculate uniform bounds from standard deviation
|
||||
with torch.no_grad():
|
||||
return tensor.uniform_(-bound, bound)
|
||||
|
||||
|
||||
def variance_scaling_init_(tensor, scale):
|
||||
return kaiming_uniform_(tensor, gain=1e-10 if scale == 0 else scale, mode='fan_avg')
|
||||
|
||||
|
||||
def dense(in_channels, out_channels, init_scale=1.):
|
||||
lin = nn.Linear(in_channels, out_channels)
|
||||
variance_scaling_init_(lin.weight, scale=init_scale)
|
||||
nn.init.zeros_(lin.bias)
|
||||
return lin
|
||||
|
||||
def conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=1, dilation=1, padding=1, bias=True, padding_mode='zeros',
|
||||
init_scale=1.):
|
||||
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
bias=bias, padding_mode=padding_mode)
|
||||
variance_scaling_init_(conv.weight, scale=init_scale)
|
||||
if bias:
|
||||
nn.init.zeros_(conv.bias)
|
||||
return conv
|
||||
|
||||
|
||||
|
37
models/distributions.py
Normal file
37
models/distributions.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
@torch.jit.script
|
||||
def sample_normal_jit(mu, sigma):
|
||||
rho = mu.mul(0).normal_()
|
||||
z = rho.mul_(sigma).add_(mu)
|
||||
return z, rho
|
||||
|
||||
class Normal:
|
||||
def __init__(self, mu, log_sigma, sigma=None):
|
||||
self.mu = mu
|
||||
self.log_sigma = log_sigma
|
||||
self.sigma = torch.exp(log_sigma) if sigma is None else sigma
|
||||
|
||||
def sample(self, t=1.):
|
||||
return sample_normal_jit(self.mu, self.sigma * t)
|
||||
|
||||
def sample_given_rho(self, rho):
|
||||
return rho * self.sigma + self.mu
|
||||
|
||||
def mean(self):
|
||||
return self.mu
|
||||
|
||||
def log_p(self, samples):
|
||||
normalized_samples = (samples - self.mu) / self.sigma
|
||||
log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - self.log_sigma
|
||||
return log_p
|
||||
|
||||
|
273
models/latent_points_ada.py
Normal file
273
models/latent_points_ada.py
Normal file
|
@ -0,0 +1,273 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .pvcnn2_ada import \
|
||||
create_pointnet2_sa_components, create_pointnet2_fp_modules, LinearAttention, create_mlp_components, SharedMLP
|
||||
|
||||
# the building block of encode and decoder for VAE
|
||||
|
||||
class PVCNN2Unet(nn.Module):
|
||||
"""
|
||||
copied and modified from https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py#L172
|
||||
"""
|
||||
def __init__(self,
|
||||
num_classes, embed_dim, use_att, dropout=0.1,
|
||||
extra_feature_channels=3,
|
||||
input_dim=3,
|
||||
width_multiplier=1,
|
||||
voxel_resolution_multiplier=1,
|
||||
time_emb_scales=1.0,
|
||||
verbose=True,
|
||||
condition_input=False,
|
||||
point_as_feat=1, cfg={},
|
||||
sa_blocks={}, fp_blocks={},
|
||||
clip_forge_enable=0,
|
||||
clip_forge_dim=512
|
||||
):
|
||||
super().__init__()
|
||||
logger.info('[Build Unet] extra_feature_channels={}, input_dim={}',
|
||||
extra_feature_channels, input_dim)
|
||||
self.input_dim = input_dim
|
||||
|
||||
self.clip_forge_enable = clip_forge_enable
|
||||
self.sa_blocks = sa_blocks
|
||||
self.fp_blocks = fp_blocks
|
||||
self.point_as_feat = point_as_feat
|
||||
self.condition_input = condition_input
|
||||
assert extra_feature_channels >= 0
|
||||
self.time_emb_scales = time_emb_scales
|
||||
self.embed_dim = embed_dim
|
||||
## assert(self.embed_dim == 0)
|
||||
if self.embed_dim > 0: # has time embedding
|
||||
# for prior model, we have time embedding, for VAE model, no time embedding
|
||||
self.embedf = nn.Sequential(
|
||||
nn.Linear(embed_dim, embed_dim),
|
||||
nn.LeakyReLU(0.1, inplace=True),
|
||||
nn.Linear(embed_dim, embed_dim),
|
||||
)
|
||||
|
||||
if self.clip_forge_enable:
|
||||
self.clip_forge_mapping = nn.Linear(clip_forge_dim, embed_dim)
|
||||
style_dim = cfg.latent_pts.style_dim
|
||||
self.style_clip = nn.Linear(style_dim + embed_dim, style_dim)
|
||||
|
||||
self.in_channels = extra_feature_channels + 3
|
||||
|
||||
sa_layers, sa_in_channels, channels_sa_features, _ = \
|
||||
create_pointnet2_sa_components(
|
||||
input_dim=input_dim,
|
||||
sa_blocks=self.sa_blocks,
|
||||
extra_feature_channels=extra_feature_channels,
|
||||
with_se=True,
|
||||
embed_dim=embed_dim, # time embedding dim
|
||||
use_att=use_att, dropout=dropout,
|
||||
width_multiplier=width_multiplier,
|
||||
voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
verbose=verbose, cfg=cfg
|
||||
)
|
||||
self.sa_layers = nn.ModuleList(sa_layers)
|
||||
|
||||
self.global_att = None if not use_att else LinearAttention(channels_sa_features, 8, verbose=verbose)
|
||||
|
||||
# only use extra features in the last fp module
|
||||
sa_in_channels[0] = extra_feature_channels + input_dim - 3
|
||||
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
|
||||
fp_blocks=self.fp_blocks, in_channels=channels_sa_features,
|
||||
sa_in_channels=sa_in_channels,
|
||||
with_se=True, embed_dim=embed_dim,
|
||||
use_att=use_att, dropout=dropout,
|
||||
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier,
|
||||
verbose=verbose, cfg=cfg
|
||||
)
|
||||
self.fp_layers = nn.ModuleList(fp_layers)
|
||||
|
||||
layers, _ = create_mlp_components(
|
||||
in_channels=channels_fp_features,
|
||||
out_channels=[128, dropout, num_classes], # was 0.5
|
||||
classifier=True, dim=2, width_multiplier=width_multiplier,
|
||||
cfg=cfg)
|
||||
self.classifier = nn.ModuleList(layers)
|
||||
|
||||
def get_timestep_embedding(self, timesteps, device):
|
||||
if len(timesteps.shape) == 2 and timesteps.shape[1] == 1:
|
||||
timesteps = timesteps[:,0]
|
||||
assert(len(timesteps.shape) == 1), f'get shape: {timesteps.shape}'
|
||||
timesteps = timesteps * self.time_emb_scales
|
||||
|
||||
half_dim = self.embed_dim // 2
|
||||
emb = np.log(10000) / (half_dim - 1)
|
||||
emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
|
||||
emb = timesteps[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if self.embed_dim % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), "constant", 0)
|
||||
assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim])
|
||||
return emb
|
||||
|
||||
def forward(self, inputs, **kwargs):
|
||||
# Input: coords: B3N
|
||||
B = inputs.shape[0]
|
||||
coords = inputs[:, :self.input_dim, :].contiguous()
|
||||
features = inputs
|
||||
temb = kwargs.get('t', None)
|
||||
if temb is not None:
|
||||
t = temb
|
||||
if t.ndim == 0 and not len(t.shape) == 1:
|
||||
t = t.view(1).expand(B)
|
||||
temb = self.embedf(self.get_timestep_embedding(t, inputs.device
|
||||
))[:,:,None].expand(-1,-1,inputs.shape[-1])
|
||||
temb_ori = temb # B,embed_dim,Npoint
|
||||
|
||||
style = kwargs['style']
|
||||
if self.clip_forge_enable:
|
||||
clip_feat = kwargs['clip_feat']
|
||||
assert(clip_feat is not None), f'require clip_feat as input'
|
||||
clip_feat = self.clip_forge_mapping(clip_feat)
|
||||
style = torch.cat([style, clip_feat], dim=1).contiguous()
|
||||
style = self.style_clip(style)
|
||||
|
||||
coords_list, in_features_list = [], []
|
||||
for i, sa_blocks in enumerate(self.sa_layers):
|
||||
in_features_list.append(features)
|
||||
coords_list.append(coords)
|
||||
if i > 0 and temb is not None:
|
||||
#TODO: implement a sa_blocks forward function; check if is PVConv layer and kwargs get grid_emb, take as additional input
|
||||
features = torch.cat([features,temb],dim=1)
|
||||
features, coords, temb, _ = \
|
||||
sa_blocks ((features,
|
||||
coords, temb, style))
|
||||
else: # i == 0 or temb is None
|
||||
features, coords, temb, _ = \
|
||||
sa_blocks ((features, coords, temb, style))
|
||||
|
||||
in_features_list[0] = inputs[:, 3:, :].contiguous()
|
||||
if self.global_att is not None:
|
||||
features = self.global_att(features)
|
||||
for fp_idx, fp_blocks in enumerate(self.fp_layers):
|
||||
if temb is not None:
|
||||
features, coords, temb, _ = fp_blocks((
|
||||
coords_list[-1-fp_idx], coords,
|
||||
torch.cat([features,temb],dim=1),
|
||||
in_features_list[-1-fp_idx], temb, style))
|
||||
else:
|
||||
features, coords, temb, _ = fp_blocks((
|
||||
coords_list[-1-fp_idx], coords,
|
||||
features,
|
||||
in_features_list[-1-fp_idx], temb, style))
|
||||
|
||||
for l in self.classifier:
|
||||
if isinstance(l, SharedMLP):
|
||||
features = l(features, style)
|
||||
else:
|
||||
features = l(features)
|
||||
return features
|
||||
|
||||
class PointTransPVC(nn.Module):
|
||||
# encoder : B,N,3 -> B,N,2*D
|
||||
sa_blocks = [ # conv_configs, sa_configs
|
||||
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
|
||||
((64, 3, 16), (256, 0.2, 32, (64, 128))),
|
||||
((128, 3, 8), (64, 0.4, 32, (128, 256))),
|
||||
(None, (16, 0.8, 32, (128, 128, 128))),
|
||||
]
|
||||
fp_blocks = [
|
||||
((128, 128), (128, 3, 8)), # fp_configs, conv_configs
|
||||
((128, 128), (128, 3, 8)),
|
||||
((128, 128), (128, 2, 16)),
|
||||
((128, 128, 64), (64, 2, 32)),
|
||||
]
|
||||
|
||||
def __init__(self, zdim, input_dim, args={}):
|
||||
super().__init__()
|
||||
self.zdim = zdim
|
||||
self.layers = PVCNN2Unet(2*zdim+input_dim*2,
|
||||
embed_dim=0, use_att=1, extra_feature_channels=0,
|
||||
input_dim=args.ddpm.input_dim, cfg=args,
|
||||
sa_blocks=self.sa_blocks, fp_blocks=self.fp_blocks,
|
||||
dropout=args.ddpm.dropout)
|
||||
self.skip_weight = args.latent_pts.skip_weight
|
||||
self.pts_sigma_offset = args.latent_pts.pts_sigma_offset
|
||||
self.input_dim = input_dim
|
||||
|
||||
def forward(self, inputs):
|
||||
x, style = inputs
|
||||
B,N,D = x.shape
|
||||
output = self.layers(x.permute(0,2,1).contiguous(), style=style).permute(0,2,1).contiguous() # BND
|
||||
|
||||
pt_mu_1d = output[:,:,:self.input_dim].contiguous()
|
||||
pt_sigma_1d = output[:,:,self.input_dim:2*self.input_dim].contiguous() - self.pts_sigma_offset
|
||||
|
||||
pt_mu_1d = self.skip_weight * pt_mu_1d + x
|
||||
if self.zdim > 0:
|
||||
ft_mu_1d = output[:,:,2*self.input_dim:-self.zdim].contiguous()
|
||||
ft_sigma_1d = output[:,:,-self.zdim:].contiguous()
|
||||
|
||||
mu_1d = torch.cat([pt_mu_1d, ft_mu_1d], dim=2).view(B,-1).contiguous()
|
||||
sigma_1d = torch.cat([pt_sigma_1d, ft_sigma_1d], dim=2).view(B,-1).contiguous()
|
||||
else:
|
||||
mu_1d = pt_mu_1d.view(B,-1).contiguous()
|
||||
sigma_1d = pt_sigma_1d.view(B,-1).contiguous()
|
||||
return {'mu_1d': mu_1d, 'sigma_1d': sigma_1d}
|
||||
|
||||
class LatentPointDecPVC(nn.Module):
|
||||
""" input x: [B,Npoint,D] with [B,Npoint,3]
|
||||
"""
|
||||
sa_blocks = [ # conv_configs, sa_configs
|
||||
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
|
||||
((64, 3, 16), (256, 0.2, 32, (64, 128))),
|
||||
((128, 3, 8), (64, 0.4, 32, (128, 256))),
|
||||
(None, (16, 0.8, 32, (128, 128, 128))),
|
||||
]
|
||||
fp_blocks = [
|
||||
((128, 128), (128, 3, 8)), # fp_configs, conv_configs
|
||||
((128, 128), (128, 3, 8)),
|
||||
((128, 128), (128, 2, 16)),
|
||||
((128, 128, 64), (64, 2, 32)),
|
||||
]
|
||||
|
||||
def __init__(self, point_dim, context_dim, num_points=None, args={}, **kwargs):
|
||||
super().__init__()
|
||||
self.point_dim = point_dim
|
||||
logger.info('[Build Dec] point_dim={}, context_dim={}', point_dim, context_dim)
|
||||
self.context_dim = context_dim + self.point_dim
|
||||
# self.num_points = num_points
|
||||
if num_points is None:
|
||||
self.num_points = args.data.tr_max_sample_points
|
||||
else:
|
||||
self.num_points = num_points
|
||||
self.layers = PVCNN2Unet(point_dim, embed_dim=0, use_att=1,
|
||||
extra_feature_channels=context_dim,
|
||||
input_dim=args.ddpm.input_dim, cfg=args,
|
||||
sa_blocks=self.sa_blocks, fp_blocks=self.fp_blocks,
|
||||
dropout=args.ddpm.dropout)
|
||||
self.skip_weight = args.latent_pts.skip_weight
|
||||
|
||||
def forward(self, x, beta, context, style):
|
||||
"""
|
||||
Args:
|
||||
x: Point clouds at some timestep t, (B, N, d). [not used]
|
||||
beta: Time. (B, ). [not used]
|
||||
context: Latent points, (B,N_pts*D_latent_pts), D_latent_pts = D_input + D_extra
|
||||
style: Shape latents. (B,d).
|
||||
Returns:
|
||||
points: (B,N,3)
|
||||
"""
|
||||
|
||||
# CHECKDIM(context, 1, self.num_points*self.context_dim)
|
||||
assert(context.shape[1] == self.num_points*self.context_dim)
|
||||
context = context.view(-1,self.num_points,self.context_dim) # BND
|
||||
x = context[:,:,:self.point_dim]
|
||||
output = self.layers(context.permute(0,2,1).contiguous(), style=style).permute(0,2,1).contiguous() # BN3
|
||||
output = output * self.skip_weight + x
|
||||
return output
|
||||
|
84
models/latent_points_ada_localprior.py
Normal file
84
models/latent_points_ada_localprior.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import torch
|
||||
from loguru import logger
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .latent_points_ada import PVCNN2Unet
|
||||
from .utils import mask_inactive_variables
|
||||
|
||||
# diffusion model for latent points
|
||||
class PVCNN2Prior(PVCNN2Unet):
|
||||
sa_blocks = [ # conv_configs, sa_configs
|
||||
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
|
||||
((64, 3, 16), (256, 0.2, 32, (64, 128))),
|
||||
((128, 3, 8), (64, 0.4, 32, (128, 128))),
|
||||
(None, (16, 0.8, 32, (128, 128, 128))),
|
||||
]
|
||||
fp_blocks = [
|
||||
((128, 128), (128, 3, 8)), # fp_configs, conv_configs
|
||||
((128, 128), (128, 3, 8)),
|
||||
((128, 128), (128, 2, 16)),
|
||||
((128, 128, 64), (64, 2, 32)),
|
||||
]
|
||||
|
||||
def __init__(self, args, num_input_channels, cfg):
|
||||
|
||||
# only cfg is used
|
||||
self.clip_forge_enable = cfg.clipforge.enable
|
||||
clip_forge_dim = cfg.clipforge.feat_dim
|
||||
num_input_channels = num_classes = cfg.shapelatent.latent_dim + cfg.ddpm.input_dim
|
||||
self.num_classes = num_classes
|
||||
embed_dim = cfg.ddpm.time_dim
|
||||
use_att = True
|
||||
extra_feature_channels = cfg.shapelatent.latent_dim
|
||||
self.num_points = cfg.data.tr_max_sample_points
|
||||
dropout = cfg.ddpm.dropout
|
||||
time_emb_scales = cfg.sde.embedding_scale # 1k default
|
||||
logger.info('[Build Prior Model] nclass={}, embed_dim={}, use_att={},'
|
||||
'extra_feature_channels={}, dropout={}, time_emb_scales={} num_point={}',
|
||||
num_classes, embed_dim, use_att, extra_feature_channels, dropout, time_emb_scales,
|
||||
self.num_points)
|
||||
# Attention: we are not using time_emb_scales here, but the embedding_scale
|
||||
super().__init__(
|
||||
num_classes, embed_dim, use_att, dropout=dropout,
|
||||
input_dim=cfg.ddpm.input_dim,
|
||||
extra_feature_channels=extra_feature_channels,
|
||||
time_emb_scales=time_emb_scales,
|
||||
verbose=True,
|
||||
condition_input=False,
|
||||
cfg=cfg,
|
||||
sa_blocks=self.sa_blocks,
|
||||
fp_blocks=self.fp_blocks,
|
||||
clip_forge_enable=self.clip_forge_enable, clip_forge_dim=clip_forge_dim)
|
||||
# init mixing logit
|
||||
self.mixed_prediction = cfg.sde.mixed_prediction # This enables mixed prediction
|
||||
if self.mixed_prediction:
|
||||
logger.info('init-mixing_logit = {}, after sigmoid = {}',
|
||||
cfg.sde.mixing_logit_init, torch.sigmoid(torch.tensor(cfg.sde.mixing_logit_init))
|
||||
)
|
||||
init = cfg.sde.mixing_logit_init * torch.ones(size=[1, num_input_channels*self.num_points, 1, 1])
|
||||
self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
|
||||
self.is_active = None
|
||||
else: # no mixing_logit
|
||||
self.mixing_logit = None
|
||||
self.is_active = None
|
||||
|
||||
def forward(self, x, t, *args, **kwargs): #x0=None):
|
||||
# Input: x: B,ND or B,ND,1,1
|
||||
# require shape for x: B,C,N
|
||||
## CHECKEQ(x.shape[-1], self.num_classes)
|
||||
assert('condition_input' in kwargs), 'require condition_input'
|
||||
if self.mixed_prediction and self.is_active is not None:
|
||||
x = mask_inactive_variables(x, self.is_active)
|
||||
input_shape = x.shape
|
||||
x = x.view(-1,self.num_points,self.num_classes).permute(0,2,1).contiguous()
|
||||
B = x.shape[0]
|
||||
out = super().forward(x, t=t, style=kwargs['condition_input'].squeeze(-1).squeeze(-1), clip_feat=kwargs.get('clip_feat', None))
|
||||
return out.permute(0,2,1).contiguous().view(input_shape)
|
||||
# -1,self.num_classes) # BDN -> BND -> BN,D
|
91
models/lion.py
Normal file
91
models/lion.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
from models.vae_adain import Model as VAE
|
||||
from models.latent_points_ada_localprior import PVCNN2Prior as LocalPrior
|
||||
from utils.diffusion_pvd import DiffusionDiscretized
|
||||
from utils.vis_helper import plot_points
|
||||
from utils.model_helper import import_model
|
||||
from diffusers import DDPMScheduler
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
class LION(object):
|
||||
def __init__(self, cfg):
|
||||
self.vae = VAE(cfg).cuda()
|
||||
GlobalPrior = import_model(cfg.latent_pts.style_prior)
|
||||
global_prior = GlobalPrior(cfg.sde, cfg.latent_pts.style_dim, cfg).cuda()
|
||||
local_prior = LocalPrior(cfg.sde, cfg.shapelatent.latent_dim, cfg).cuda()
|
||||
self.priors = torch.nn.ModuleList([global_prior, local_prior])
|
||||
self.scheduler = DDPMScheduler(clip_sample=False,
|
||||
beta_start=cfg.ddpm.beta_1, beta_end=cfg.ddpm.beta_T, beta_schedule=cfg.ddpm.sched_mode,
|
||||
num_train_timesteps=cfg.ddpm.num_steps, variance_type=cfg.ddpm.model_var_type)
|
||||
self.diffusion = DiffusionDiscretized(None, None, cfg)
|
||||
# self.load_model(cfg)
|
||||
|
||||
def load_model(self, model_path):
|
||||
# model_path = cfg.ckpt.path
|
||||
ckpt = torch.load(model_path)
|
||||
self.priors.load_state_dict(ckpt['dae_state_dict'])
|
||||
self.vae.load_state_dict(ckpt['vae_state_dict'])
|
||||
print(f'INFO finish loading from {model_path}')
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, num_samples=10, clip_feat=None, save_img=False):
|
||||
self.scheduler.set_timesteps(1000, device='cuda')
|
||||
timesteps = self.scheduler.timesteps
|
||||
latent_shape = self.vae.latent_shape()
|
||||
global_prior, local_prior = self.priors[0], self.priors[1]
|
||||
assert(not local_prior.mixed_prediction and not global_prior.mixed_prediction)
|
||||
sampled_list = []
|
||||
output_dict = {}
|
||||
|
||||
# start sample global prior
|
||||
x_T_shape = [num_samples] + latent_shape[0]
|
||||
x_noisy = torch.randn(size=x_T_shape, device='cuda')
|
||||
condition_input = None
|
||||
for i, t in enumerate(timesteps):
|
||||
t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1)
|
||||
noise_pred = global_prior(x=x_noisy, t=t_tensor.float(),
|
||||
condition_input=condition_input, clip_feat=clip_feat)
|
||||
x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
|
||||
sampled_list.append(x_noisy)
|
||||
output_dict['z_global'] = x_noisy
|
||||
|
||||
condition_input = x_noisy
|
||||
condition_input = self.vae.global2style(condition_input)
|
||||
|
||||
# start sample local prior
|
||||
x_T_shape = [num_samples] + latent_shape[1]
|
||||
x_noisy = torch.randn(size=x_T_shape, device='cuda')
|
||||
|
||||
for i, t in enumerate(timesteps):
|
||||
t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1)
|
||||
noise_pred = local_prior(x=x_noisy, t=t_tensor.float(),
|
||||
condition_input=condition_input, clip_feat=clip_feat)
|
||||
x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
|
||||
sampled_list.append(x_noisy)
|
||||
output_dict['z_local'] = x_noisy
|
||||
|
||||
# decode the latent
|
||||
output = self.vae.sample(num_samples=num_samples, decomposed_eps=sampled_list)
|
||||
if save_img:
|
||||
out_name = plot_points(output, "/tmp/tmp.png")
|
||||
print(f'INFO save plot image at {out_name}')
|
||||
output_dict['points'] = output
|
||||
return output_dict
|
||||
|
||||
def get_mixing_component(self, noise_pred, t):
|
||||
# usage:
|
||||
# if global_prior.mixed_prediction:
|
||||
# mixing_component = self.get_mixing_component(noise_pred, t)
|
||||
# coeff = torch.sigmoid(global_prior.mixing_logit)
|
||||
# noise_pred = (1 - coeff) * mixing_component + coeff * noise_pred
|
||||
|
||||
alpha_bar = self.scheduler.alphas_cumprod[t]
|
||||
one_minus_alpha_bars_sqrt = np.sqrt(1.0 - alpha_bar)
|
||||
return noise_pred * one_minus_alpha_bars_sqrt
|
557
models/pvcnn2.py
Normal file
557
models/pvcnn2.py
Normal file
|
@ -0,0 +1,557 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
"""
|
||||
copied and modified from source:
|
||||
https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py
|
||||
and functions under
|
||||
https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules
|
||||
"""
|
||||
import copy
|
||||
import functools
|
||||
from loguru import logger
|
||||
from einops import rearrange
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
import third_party.pvcnn.functional as F
|
||||
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||||
|
||||
class SE3d(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.channel = channel
|
||||
def __repr__(self):
|
||||
return f"SE({self.channel}, {self.channel})"
|
||||
def forward(self, inputs):
|
||||
return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
"""
|
||||
copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159
|
||||
"""
|
||||
def __init__(self, dim, heads = 4, dim_head = 32, verbose=True):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Args:
|
||||
x: torch.tensor (B,C,N), C=num-channels, N=num-points
|
||||
Returns:
|
||||
out: torch.tensor (B,C,N)
|
||||
'''
|
||||
x = x.unsqueeze(-1) # add w dimension
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
out = self.to_out(out)
|
||||
out = out.squeeze(-1) # B,C,N,1 -> B,C,N
|
||||
return out
|
||||
|
||||
|
||||
def swish(input):
|
||||
return input * torch.sigmoid(input)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return swish(input)
|
||||
|
||||
|
||||
class BallQuery(nn.Module):
|
||||
def __init__(self, radius, num_neighbors, include_coordinates=True):
|
||||
super().__init__()
|
||||
self.radius = radius
|
||||
self.num_neighbors = num_neighbors
|
||||
self.include_coordinates = include_coordinates
|
||||
|
||||
@custom_bwd
|
||||
def backward(self, *args, **kwargs):
|
||||
return super().backward(*args, **kwargs)
|
||||
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(self, points_coords, centers_coords, points_features=None):
|
||||
# input: BCN, BCN
|
||||
# returns:
|
||||
# neighbor_features: B,D(+3),Ncenter
|
||||
points_coords = points_coords.contiguous()
|
||||
centers_coords = centers_coords.contiguous()
|
||||
neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
|
||||
neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
|
||||
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
|
||||
|
||||
if points_features is None:
|
||||
assert self.include_coordinates, 'No Features For Grouping'
|
||||
neighbor_features = neighbor_coordinates
|
||||
else:
|
||||
neighbor_features = F.grouping(points_features, neighbor_indices)
|
||||
if self.include_coordinates:
|
||||
neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
|
||||
return neighbor_features
|
||||
|
||||
def extra_repr(self):
|
||||
return 'radius={}, num_neighbors={}{}'.format(
|
||||
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
|
||||
|
||||
class SharedMLP(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dim=1):
|
||||
super().__init__()
|
||||
if dim==1:
|
||||
conv = nn.Conv1d
|
||||
else:
|
||||
conv = nn.Conv2d
|
||||
bn = nn.GroupNorm
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [out_channels]
|
||||
layers = []
|
||||
for oc in out_channels:
|
||||
layers.append( conv(in_channels, oc, 1))
|
||||
layers.append(bn(8, oc))
|
||||
layers.append(Swish())
|
||||
in_channels = oc
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, inputs):
|
||||
if isinstance(inputs, (list, tuple)):
|
||||
return (self.layers(inputs[0]), *inputs[1:])
|
||||
else:
|
||||
return self.layers(inputs)
|
||||
|
||||
class Voxelization(nn.Module):
|
||||
def __init__(self, resolution, normalize=True, eps=0):
|
||||
super().__init__()
|
||||
self.r = int(resolution)
|
||||
self.normalize = normalize
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, features, coords):
|
||||
# features: B,D,N
|
||||
# coords: B,3,N
|
||||
coords = coords.detach()
|
||||
norm_coords = coords - coords.mean(2, keepdim=True)
|
||||
if self.normalize:
|
||||
norm_coords = norm_coords / (norm_coords.norm(
|
||||
dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 +
|
||||
self.eps) + 0.5
|
||||
else:
|
||||
norm_coords = (norm_coords + 1) / 2.0
|
||||
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
|
||||
vox_coords = torch.round(norm_coords).to(torch.int32)
|
||||
if features is None:
|
||||
return features, norm_coords
|
||||
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
|
||||
|
||||
def extra_repr(self):
|
||||
return 'resolution={}{}'.format(
|
||||
self.r,
|
||||
', normalized eps = {}'.format(self.eps) if self.normalize else '')
|
||||
|
||||
class PVConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels,
|
||||
kernel_size, resolution,
|
||||
normalize=1, eps=0, with_se=False,
|
||||
add_point_feat=True, attention=False,
|
||||
dropout=0.1, verbose=True
|
||||
):
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.voxelization = Voxelization(resolution,
|
||||
normalize=normalize,
|
||||
eps=eps)
|
||||
# For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention)
|
||||
voxel_layers = [
|
||||
nn.Conv3d(in_channels,
|
||||
out_channels,
|
||||
kernel_size, stride=1,
|
||||
padding=kernel_size // 2),
|
||||
nn.GroupNorm(8, out_channels),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv3d(out_channels, out_channels,
|
||||
kernel_size, stride=1,
|
||||
padding=kernel_size // 2),
|
||||
nn.GroupNorm(8, out_channels)
|
||||
]
|
||||
if with_se:
|
||||
voxel_layers.append(SE3d(out_channels))
|
||||
self.voxel_layers = nn.Sequential(*voxel_layers)
|
||||
if attention:
|
||||
self.attn = LinearAttention(out_channels, verbose=verbose)
|
||||
else:
|
||||
self.attn = None
|
||||
if add_point_feat:
|
||||
self.point_features = SharedMLP(in_channels, out_channels) #, **mlp_kwargs)
|
||||
self.add_point_feat = add_point_feat
|
||||
|
||||
def forward(self, inputs):
|
||||
'''
|
||||
Args:
|
||||
inputs: tuple of features and coords
|
||||
features: B,feat-dim,num-points
|
||||
coords: B,3, num-points
|
||||
Returns:
|
||||
fused_features: in (B,out-feat-dim,num-points)
|
||||
coords : in (B, 3, num_points); same as the input coords
|
||||
'''
|
||||
features = inputs[0]
|
||||
coords_input = inputs[1]
|
||||
time_emb = inputs[2]
|
||||
## features, coords_input, time_emb = inputs
|
||||
if coords_input.shape[1] > 3:
|
||||
coords = coords_input[:,:3] # the last 3 dim are other point attributes if any
|
||||
else:
|
||||
coords = coords_input
|
||||
assert (features.shape[0] == coords.shape[0]
|
||||
), f'get feat: {features.shape} and {coords.shape}'
|
||||
assert (features.shape[2] == coords.shape[2]
|
||||
), f'get feat: {features.shape} and {coords.shape}'
|
||||
assert (coords.shape[1] == 3
|
||||
), f'expect coords: B,3,Npoint, get: {coords.shape}'
|
||||
# features: B,D,N; point_features
|
||||
# coords: B,3,N
|
||||
voxel_features_4d, voxel_coords = self.voxelization(features, coords)
|
||||
r = self.resolution
|
||||
B = coords.shape[0]
|
||||
voxel_features_4d = self.voxel_layers(voxel_features_4d)
|
||||
voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords,
|
||||
r, self.training)
|
||||
|
||||
fused_features = voxel_features
|
||||
if self.add_point_feat:
|
||||
fused_features = fused_features + self.point_features(features)
|
||||
if self.attn is not None:
|
||||
fused_features = self.attn(fused_features)
|
||||
if time_emb is None:
|
||||
time_emb = {'voxel_features_4d': voxel_features_4d, 'resolution': self.resolution, 'training': self.training}
|
||||
return fused_features, coords_input, time_emb #inputs[2]
|
||||
|
||||
|
||||
class PointNetAModule(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, include_coordinates=True):
|
||||
super().__init__()
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [[out_channels]]
|
||||
elif not isinstance(out_channels[0], (list, tuple)):
|
||||
out_channels = [out_channels]
|
||||
|
||||
mlps = []
|
||||
total_out_channels = 0
|
||||
for _out_channels in out_channels:
|
||||
mlps.append(
|
||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
||||
out_channels=_out_channels, dim=1)
|
||||
)
|
||||
total_out_channels += _out_channels[-1]
|
||||
|
||||
self.include_coordinates = include_coordinates
|
||||
self.out_channels = total_out_channels
|
||||
self.mlps = nn.ModuleList(mlps)
|
||||
|
||||
def forward(self, inputs):
|
||||
features, coords, time_emb = inputs
|
||||
if self.include_coordinates:
|
||||
features = torch.cat([features, coords], dim=1)
|
||||
coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
|
||||
if len(self.mlps) > 1:
|
||||
features_list = []
|
||||
for mlp in self.mlps:
|
||||
features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
|
||||
return torch.cat(features_list, dim=1), coords, time_emb
|
||||
else:
|
||||
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords, time_emb
|
||||
|
||||
def extra_repr(self):
|
||||
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
|
||||
|
||||
|
||||
class PointNetSAModule(nn.Module):
|
||||
def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
|
||||
super().__init__()
|
||||
|
||||
if not isinstance(radius, (list, tuple)):
|
||||
radius = [radius]
|
||||
if not isinstance(num_neighbors, (list, tuple)):
|
||||
num_neighbors = [num_neighbors] * len(radius)
|
||||
assert len(radius) == len(num_neighbors)
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [[out_channels]] * len(radius)
|
||||
elif not isinstance(out_channels[0], (list, tuple)):
|
||||
out_channels = [out_channels] * len(radius)
|
||||
assert len(radius) == len(out_channels)
|
||||
|
||||
groupers, mlps = [], []
|
||||
total_out_channels = 0
|
||||
for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
|
||||
groupers.append(
|
||||
BallQuery(radius=_radius, num_neighbors=_num_neighbors,
|
||||
include_coordinates=include_coordinates)
|
||||
)
|
||||
# logger.info('create MLP: in_channel={}, out_channels={}',
|
||||
# in_channels + (3 if include_coordinates else 0),_out_channels)
|
||||
mlps.append(
|
||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0) ,
|
||||
out_channels=_out_channels, dim=2)
|
||||
)
|
||||
total_out_channels += _out_channels[-1]
|
||||
|
||||
self.num_centers = num_centers
|
||||
self.out_channels = total_out_channels
|
||||
self.groupers = nn.ModuleList(groupers)
|
||||
self.mlps = nn.ModuleList(mlps)
|
||||
|
||||
def forward(self, inputs):
|
||||
# features, coords, _ = inputs
|
||||
features = inputs[0]
|
||||
coords = inputs[1] # B3N
|
||||
if coords.shape[1] > 3:
|
||||
coords = coords[:,:3]
|
||||
|
||||
centers_coords = F.furthest_point_sample(coords, self.num_centers)
|
||||
# centers_coords: B,D,N
|
||||
S = centers_coords.shape[-1]
|
||||
time_emb = inputs[2]
|
||||
time_emb = time_emb[:,:,:S] if \
|
||||
time_emb is not None and type(time_emb) is not dict \
|
||||
else time_emb
|
||||
|
||||
features_list = []
|
||||
c = 0
|
||||
for grouper, mlp in zip(self.groupers, self.mlps):
|
||||
c += 1
|
||||
grouper_output = grouper(coords, centers_coords, features)
|
||||
features_list.append(
|
||||
mlp(grouper_output
|
||||
).max(dim=-1).values
|
||||
)
|
||||
if len(features_list) > 1:
|
||||
return torch.cat(features_list, dim=1), centers_coords, time_emb
|
||||
else:
|
||||
return features_list[0], centers_coords, time_emb
|
||||
|
||||
def extra_repr(self):
|
||||
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
|
||||
|
||||
|
||||
class PointNetFPModule(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
if len(inputs) == 4:
|
||||
points_coords, centers_coords, centers_features, time_emb = inputs
|
||||
points_features = None
|
||||
else:
|
||||
points_coords, centers_coords, centers_features, points_features, time_emb = inputs
|
||||
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
|
||||
if points_features is not None:
|
||||
interpolated_features = torch.cat(
|
||||
[interpolated_features, points_features], dim=1
|
||||
)
|
||||
if time_emb is not None:
|
||||
B,D,S = time_emb.shape
|
||||
N = points_coords.shape[-1]
|
||||
time_emb = time_emb[:,:,0:1].expand(-1,-1,N)
|
||||
return self.mlp(interpolated_features), points_coords, time_emb
|
||||
|
||||
def _linear_gn_relu(in_channels, out_channels):
|
||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
|
||||
|
||||
|
||||
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
|
||||
r = width_multiplier
|
||||
|
||||
if dim == 1:
|
||||
block = _linear_gn_relu
|
||||
else:
|
||||
block = SharedMLP
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [out_channels]
|
||||
if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
|
||||
return nn.Sequential(), in_channels, in_channels
|
||||
|
||||
layers = []
|
||||
for oc in out_channels[:-1]:
|
||||
if oc < 1:
|
||||
layers.append(nn.Dropout(oc))
|
||||
else:
|
||||
oc = int(r * oc)
|
||||
layers.append(block(in_channels, oc))
|
||||
in_channels = oc
|
||||
if dim == 1:
|
||||
if classifier:
|
||||
layers.append(nn.Linear(in_channels, out_channels[-1]))
|
||||
else:
|
||||
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
|
||||
else:
|
||||
if classifier:
|
||||
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
|
||||
else:
|
||||
layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
|
||||
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
||||
|
||||
|
||||
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1, verbose=True):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
layers, concat_channels = [], 0
|
||||
c = 0
|
||||
for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = k % 2 == 0 and k > 0 and p == 0
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
|
||||
with_se=with_se, normalize=normalize, eps=eps, verbose=verbose)
|
||||
|
||||
if c == 0:
|
||||
layers.append(block(in_channels, out_channels))
|
||||
else:
|
||||
layers.append(block(in_channels+embed_dim, out_channels))
|
||||
in_channels = out_channels
|
||||
concat_channels += out_channels
|
||||
c += 1
|
||||
return layers, in_channels, concat_channels
|
||||
|
||||
|
||||
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels,
|
||||
input_dim=3,
|
||||
embed_dim=64, use_att=False, force_att=0,
|
||||
dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1, verbose=True):
|
||||
"""
|
||||
Returns:
|
||||
in_channels: the last output channels of the sa blocks
|
||||
"""
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
in_channels = extra_feature_channels + input_dim
|
||||
|
||||
sa_layers, sa_in_channels = [], []
|
||||
c = 0
|
||||
num_centers = None
|
||||
for conv_configs, sa_configs in sa_blocks:
|
||||
k = 0
|
||||
sa_in_channels.append(in_channels)
|
||||
sa_blocks = []
|
||||
|
||||
if conv_configs is not None:
|
||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = ( (c+1) % 2 == 0 and use_att and p == 0 ) or (force_att and c > 0)
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(
|
||||
PVConv, kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se,
|
||||
normalize=normalize, eps=eps, verbose=verbose)
|
||||
|
||||
if c == 0:
|
||||
sa_blocks.append(block(in_channels, out_channels))
|
||||
elif k ==0:
|
||||
sa_blocks.append(block(in_channels+embed_dim*has_temb, out_channels))
|
||||
in_channels = out_channels
|
||||
k += 1
|
||||
extra_feature_channels = in_channels
|
||||
|
||||
if sa_configs is not None:
|
||||
num_centers, radius, num_neighbors, out_channels = sa_configs
|
||||
_out_channels = []
|
||||
for oc in out_channels:
|
||||
if isinstance(oc, (list, tuple)):
|
||||
_out_channels.append([int(r * _oc) for _oc in oc])
|
||||
else:
|
||||
_out_channels.append(int(r * oc))
|
||||
out_channels = _out_channels
|
||||
if num_centers is None:
|
||||
block = PointNetAModule
|
||||
else:
|
||||
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
|
||||
num_neighbors=num_neighbors)
|
||||
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim*has_temb if k==0 else 0 ),
|
||||
out_channels=out_channels,
|
||||
include_coordinates=True))
|
||||
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
||||
c += 1
|
||||
|
||||
if len(sa_blocks) == 1:
|
||||
sa_layers.append(sa_blocks[0])
|
||||
else:
|
||||
sa_layers.append(nn.Sequential(*sa_blocks))
|
||||
|
||||
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
||||
|
||||
|
||||
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
|
||||
dropout=0.1, has_temb=1,
|
||||
with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1,
|
||||
verbose=True):
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
fp_layers = []
|
||||
c = 0
|
||||
|
||||
for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
|
||||
fp_blocks = []
|
||||
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
||||
fp_blocks.append(
|
||||
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim*has_temb,
|
||||
out_channels=out_channels)
|
||||
)
|
||||
in_channels = out_channels[-1]
|
||||
|
||||
if conv_configs is not None:
|
||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se, # with_se_relu=True,
|
||||
normalize=normalize, eps=eps,
|
||||
verbose=verbose)
|
||||
|
||||
fp_blocks.append(block(in_channels, out_channels))
|
||||
in_channels = out_channels
|
||||
if len(fp_blocks) == 1:
|
||||
fp_layers.append(fp_blocks[0])
|
||||
else:
|
||||
fp_layers.append(nn.Sequential(*fp_blocks))
|
||||
|
||||
c += 1
|
||||
|
||||
return fp_layers, in_channels
|
||||
|
||||
|
568
models/pvcnn2_ada.py
Normal file
568
models/pvcnn2_ada.py
Normal file
|
@ -0,0 +1,568 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
"""
|
||||
copied and modified from source:
|
||||
https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py
|
||||
and functions under
|
||||
https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules
|
||||
"""
|
||||
import copy
|
||||
import functools
|
||||
from loguru import logger
|
||||
from einops import rearrange
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
import third_party.pvcnn.functional as F
|
||||
# from utils.checker import *
|
||||
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||||
from .adagn import AdaGN
|
||||
import os
|
||||
quiet = int(os.environ.get('quiet', 0))
|
||||
class SE3d(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.channel = channel
|
||||
def __repr__(self):
|
||||
return f"SE({self.channel}, {self.channel})"
|
||||
def forward(self, inputs):
|
||||
return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
"""
|
||||
copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159
|
||||
"""
|
||||
def __init__(self, dim, heads = 4, dim_head = 32, verbose=True):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Args:
|
||||
x: torch.tensor (B,C,N), C=num-channels, N=num-points
|
||||
Returns:
|
||||
out: torch.tensor (B,C,N)
|
||||
'''
|
||||
x = x.unsqueeze(-1) # add w dimension
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
out = self.to_out(out)
|
||||
out = out.squeeze(-1) # B,C,N,1 -> B,C,N
|
||||
return out
|
||||
|
||||
|
||||
def swish(input):
|
||||
return input * torch.sigmoid(input)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return swish(input)
|
||||
|
||||
|
||||
class BallQuery(nn.Module):
|
||||
def __init__(self, radius, num_neighbors, include_coordinates=True):
|
||||
super().__init__()
|
||||
self.radius = radius
|
||||
self.num_neighbors = num_neighbors
|
||||
self.include_coordinates = include_coordinates
|
||||
|
||||
@custom_bwd
|
||||
def backward(self, *args, **kwargs):
|
||||
return super().backward(*args, **kwargs)
|
||||
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(self, points_coords, centers_coords, points_features=None):
|
||||
# input: BCN, BCN
|
||||
# neighbor_features: B,D(+3),Ncenter
|
||||
points_coords = points_coords.contiguous()
|
||||
centers_coords = centers_coords.contiguous()
|
||||
neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
|
||||
neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
|
||||
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
|
||||
|
||||
if points_features is None:
|
||||
assert self.include_coordinates, 'No Features For Grouping'
|
||||
neighbor_features = neighbor_coordinates
|
||||
else:
|
||||
neighbor_features = F.grouping(points_features, neighbor_indices)
|
||||
if self.include_coordinates:
|
||||
neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
|
||||
return neighbor_features
|
||||
|
||||
def extra_repr(self):
|
||||
return 'radius={}, num_neighbors={}{}'.format(
|
||||
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
|
||||
|
||||
class SharedMLP(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dim=1, cfg={}):
|
||||
|
||||
assert(len(cfg) > 0), cfg
|
||||
super().__init__()
|
||||
if dim==1:
|
||||
conv = nn.Conv1d
|
||||
else:
|
||||
conv = nn.Conv2d
|
||||
bn = functools.partial(AdaGN, dim, cfg)
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [out_channels]
|
||||
layers = []
|
||||
for oc in out_channels:
|
||||
layers.append(conv(in_channels, oc, 1))
|
||||
layers.append(bn(oc))
|
||||
layers.append(Swish())
|
||||
in_channels = oc
|
||||
self.layers = nn.ModuleList(layers)
|
||||
|
||||
def forward(self, *inputs):
|
||||
if len(inputs) == 1 and len(inputs[0]) == 4:
|
||||
# try to fix thwn SharedMLP is the first layer
|
||||
inputs = inputs[0]
|
||||
if len(inputs) == 1:
|
||||
raise NotImplementedError
|
||||
elif len(inputs) == 4:
|
||||
assert(len(inputs) == 4), 'input, style'
|
||||
x, _, _, style = inputs
|
||||
for l in self.layers:
|
||||
if isinstance(l, AdaGN):
|
||||
x = l(x, style)
|
||||
else:
|
||||
x = l(x)
|
||||
return (x, *inputs[1:])
|
||||
elif len(inputs) == 2:
|
||||
x, style = inputs
|
||||
for l in self.layers:
|
||||
if isinstance(l, AdaGN):
|
||||
x = l(x, style)
|
||||
else:
|
||||
x = l(x)
|
||||
return x
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
class Voxelization(nn.Module):
|
||||
def __init__(self, resolution, normalize=True, eps=0):
|
||||
super().__init__()
|
||||
self.r = int(resolution)
|
||||
self.normalize = normalize
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, features, coords):
|
||||
# features: B,D,N
|
||||
# coords: B,3,N
|
||||
coords = coords.detach()
|
||||
norm_coords = coords - coords.mean(2, keepdim=True)
|
||||
if self.normalize:
|
||||
norm_coords = norm_coords / (norm_coords.norm(
|
||||
dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 +
|
||||
self.eps) + 0.5
|
||||
else:
|
||||
norm_coords = (norm_coords + 1) / 2.0
|
||||
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
|
||||
vox_coords = torch.round(norm_coords).to(torch.int32)
|
||||
if features is None:
|
||||
return features, norm_coords
|
||||
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
|
||||
|
||||
def extra_repr(self):
|
||||
return 'resolution={}{}'.format(
|
||||
self.r,
|
||||
', normalized eps = {}'.format(self.eps) if self.normalize else '')
|
||||
|
||||
class PVConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels,
|
||||
kernel_size, resolution,
|
||||
normalize=1, eps=0, with_se=False,
|
||||
add_point_feat=True, attention=False,
|
||||
dropout=0.1, verbose=True,
|
||||
cfg={}
|
||||
):
|
||||
super().__init__()
|
||||
assert(len(cfg) > 0), cfg
|
||||
self.resolution = resolution
|
||||
self.voxelization = Voxelization(resolution,
|
||||
normalize=normalize,
|
||||
eps=eps)
|
||||
# For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention)
|
||||
NormLayer = functools.partial(AdaGN, 3, cfg)
|
||||
voxel_layers = [
|
||||
nn.Conv3d(in_channels ,
|
||||
out_channels,
|
||||
kernel_size, stride=1,
|
||||
padding=kernel_size // 2),
|
||||
NormLayer(out_channels),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv3d(out_channels, out_channels,
|
||||
kernel_size, stride=1,
|
||||
padding=kernel_size // 2),
|
||||
NormLayer(out_channels)
|
||||
]
|
||||
if with_se:
|
||||
voxel_layers.append(SE3d(out_channels))
|
||||
self.voxel_layers = nn.ModuleList(voxel_layers)
|
||||
if attention:
|
||||
self.attn = LinearAttention(out_channels, verbose=verbose)
|
||||
else:
|
||||
self.attn = None
|
||||
if add_point_feat:
|
||||
self.point_features = SharedMLP(in_channels, out_channels, cfg=cfg)
|
||||
self.add_point_feat = add_point_feat
|
||||
|
||||
def forward(self, inputs):
|
||||
'''
|
||||
Args:
|
||||
inputs: tuple of features and coords
|
||||
features: B,feat-dim,num-points
|
||||
coords: B,3, num-points
|
||||
time_emd: B,D; time embedding
|
||||
style: B,D; global latent
|
||||
Returns:
|
||||
fused_features: in (B,out-feat-dim,num-points)
|
||||
coords : in (B, 3 or 6, num_points); same as the input coords
|
||||
'''
|
||||
features = inputs[0]
|
||||
coords_input= inputs[1]
|
||||
time_emb = inputs[2]
|
||||
style = inputs[3]
|
||||
if coords_input.shape[1] > 3:
|
||||
coords = coords_input[:,:3]
|
||||
else:
|
||||
coords = coords_input
|
||||
assert (features.shape[0] == coords.shape[0]
|
||||
), f'get feat: {features.shape} and {coords.shape}'
|
||||
assert (features.shape[2] == coords.shape[2]
|
||||
), f'get feat: {features.shape} and {coords.shape}'
|
||||
assert (coords.shape[1] == 3
|
||||
), f'expect coords: B,3,Npoint, get: {coords.shape}'
|
||||
# features: B,D,N; point_features
|
||||
# coords: B,3,N
|
||||
voxel_features_4d, voxel_coords = self.voxelization(features, coords)
|
||||
r = self.resolution
|
||||
B = coords.shape[0]
|
||||
|
||||
for voxel_layers in self.voxel_layers:
|
||||
if isinstance(voxel_layers, AdaGN):
|
||||
voxel_features_4d = voxel_layers(voxel_features_4d, style)
|
||||
else:
|
||||
voxel_features_4d = voxel_layers(voxel_features_4d)
|
||||
voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords,
|
||||
r, self.training)
|
||||
|
||||
fused_features = voxel_features
|
||||
if self.add_point_feat:
|
||||
fused_features = fused_features + self.point_features(features, style)
|
||||
if self.attn is not None:
|
||||
fused_features = self.attn(fused_features)
|
||||
return fused_features, coords_input, time_emb, style
|
||||
|
||||
|
||||
class PointNetAModule(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, include_coordinates=True, cfg={}):
|
||||
super().__init__()
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [[out_channels]]
|
||||
elif not isinstance(out_channels[0], (list, tuple)):
|
||||
out_channels = [out_channels]
|
||||
|
||||
mlps = []
|
||||
total_out_channels = 0
|
||||
for _out_channels in out_channels:
|
||||
mlps.append(
|
||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
||||
out_channels=_out_channels, dim=1, cfg=cfg)
|
||||
)
|
||||
total_out_channels += _out_channels[-1]
|
||||
|
||||
self.include_coordinates = include_coordinates
|
||||
self.out_channels = total_out_channels
|
||||
self.mlps = nn.ModuleList(mlps)
|
||||
|
||||
def forward(self, inputs):
|
||||
features, coords, time_emb, style = inputs
|
||||
if self.include_coordinates:
|
||||
features = torch.cat([features, coords], dim=1)
|
||||
coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
|
||||
if len(self.mlps) > 1:
|
||||
features_list = []
|
||||
for mlp in self.mlps:
|
||||
features_list.append(mlp(features, style).max(dim=-1, keepdim=True).values)
|
||||
return torch.cat(features_list, dim=1), coords, time_emb
|
||||
else:
|
||||
return self.mlps[0](features, style).max(dim=-1, keepdim=True).values, coords, time_emb
|
||||
|
||||
def extra_repr(self):
|
||||
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
|
||||
|
||||
|
||||
class PointNetSAModule(nn.Module):
|
||||
def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True,
|
||||
cfg={}):
|
||||
super().__init__()
|
||||
if not isinstance(radius, (list, tuple)):
|
||||
radius = [radius]
|
||||
if not isinstance(num_neighbors, (list, tuple)):
|
||||
num_neighbors = [num_neighbors] * len(radius)
|
||||
assert len(radius) == len(num_neighbors)
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [[out_channels]] * len(radius)
|
||||
elif not isinstance(out_channels[0], (list, tuple)):
|
||||
out_channels = [out_channels] * len(radius)
|
||||
assert len(radius) == len(out_channels)
|
||||
|
||||
groupers, mlps = [], []
|
||||
total_out_channels = 0
|
||||
for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
|
||||
groupers.append(
|
||||
BallQuery(radius=_radius, num_neighbors=_num_neighbors,
|
||||
include_coordinates=include_coordinates)
|
||||
)
|
||||
mlps.append(
|
||||
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
|
||||
out_channels=_out_channels, dim=2, cfg=cfg)
|
||||
)
|
||||
total_out_channels += _out_channels[-1]
|
||||
|
||||
self.num_centers = num_centers
|
||||
self.out_channels = total_out_channels
|
||||
self.groupers = nn.ModuleList(groupers)
|
||||
self.mlps = nn.ModuleList(mlps)
|
||||
|
||||
def forward(self, inputs):
|
||||
features = inputs[0]
|
||||
coords = inputs[1] # B3N
|
||||
style = inputs[3]
|
||||
if coords.shape[1] > 3:
|
||||
coords = coords[:,:3]
|
||||
|
||||
centers_coords = F.furthest_point_sample(coords, self.num_centers)
|
||||
# centers_coords: B,D,N
|
||||
S = centers_coords.shape[-1]
|
||||
time_emb = inputs[2]
|
||||
time_emb = time_emb[:,:,:S] if \
|
||||
time_emb is not None and type(time_emb) is not dict \
|
||||
else time_emb
|
||||
|
||||
features_list = []
|
||||
c = 0
|
||||
for grouper, mlp in zip(self.groupers, self.mlps):
|
||||
c += 1
|
||||
grouper_output = grouper(coords, centers_coords, features )
|
||||
features_list.append(
|
||||
mlp(grouper_output, style
|
||||
).max(dim=-1).values
|
||||
)
|
||||
|
||||
if len(features_list) > 1:
|
||||
return torch.cat(features_list, dim=1), centers_coords, time_emb, style
|
||||
else:
|
||||
return features_list[0], centers_coords, time_emb, style
|
||||
|
||||
def extra_repr(self):
|
||||
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
|
||||
|
||||
|
||||
class PointNetFPModule(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, cfg={}):
|
||||
super().__init__()
|
||||
self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1, cfg=cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
if len(inputs) == 5:
|
||||
points_coords, centers_coords, centers_features, time_emb, style = inputs
|
||||
points_features = None
|
||||
elif len(inputs) == 6:
|
||||
points_coords, centers_coords, centers_features, points_features, time_emb, style = inputs
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
|
||||
if points_features is not None:
|
||||
interpolated_features = torch.cat(
|
||||
[interpolated_features, points_features], dim=1
|
||||
)
|
||||
if time_emb is not None:
|
||||
B,D,S = time_emb.shape
|
||||
N = points_coords.shape[-1]
|
||||
time_emb = time_emb[:,:,0:1].expand(-1,-1,N)
|
||||
return self.mlp(interpolated_features, style), points_coords, time_emb, style
|
||||
|
||||
def _linear_gn_relu(in_channels, out_channels):
|
||||
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
|
||||
|
||||
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1, cfg={}):
|
||||
r = width_multiplier
|
||||
|
||||
if dim == 1:
|
||||
block = _linear_gn_relu
|
||||
else:
|
||||
block = SharedMLP
|
||||
if not isinstance(out_channels, (list, tuple)):
|
||||
out_channels = [out_channels]
|
||||
if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
|
||||
return nn.Sequential(), in_channels, in_channels
|
||||
|
||||
layers = []
|
||||
for oc in out_channels[:-1]:
|
||||
if oc < 1:
|
||||
layers.append(nn.Dropout(oc))
|
||||
else:
|
||||
oc = int(r * oc)
|
||||
layers.append(block(in_channels, oc, cfg=cfg))
|
||||
in_channels = oc
|
||||
if dim == 1:
|
||||
if classifier:
|
||||
layers.append(nn.Linear(in_channels, out_channels[-1]))
|
||||
else:
|
||||
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
|
||||
else:
|
||||
if classifier:
|
||||
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
|
||||
else:
|
||||
layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
|
||||
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
|
||||
|
||||
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels,
|
||||
input_dim=3,
|
||||
embed_dim=64, use_att=False, force_att=0,
|
||||
dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1, verbose=True,
|
||||
cfg={}):
|
||||
"""
|
||||
Returns:
|
||||
in_channels: the last output channels of the sa blocks
|
||||
"""
|
||||
assert(len(cfg) > 0), cfg
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
in_channels = extra_feature_channels + input_dim
|
||||
|
||||
sa_layers, sa_in_channels = [], []
|
||||
c = 0
|
||||
num_centers = None
|
||||
for conv_configs, sa_configs in sa_blocks:
|
||||
k = 0
|
||||
sa_in_channels.append(in_channels)
|
||||
sa_blocks = []
|
||||
if conv_configs is not None:
|
||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = ( (c+1) % 2 == 0 and use_att and p == 0 ) or (force_att and c > 0)
|
||||
if voxel_resolution is None:
|
||||
block = SharedMLP
|
||||
else:
|
||||
block = functools.partial(
|
||||
PVConv, kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se, # with_se_relu=True,
|
||||
normalize=normalize, eps=eps, verbose=verbose, cfg=cfg)
|
||||
|
||||
if c == 0:
|
||||
sa_blocks.append(block(in_channels, out_channels, cfg=cfg))
|
||||
elif k ==0:
|
||||
sa_blocks.append(block(in_channels+embed_dim*has_temb, out_channels, cfg=cfg))
|
||||
in_channels = out_channels
|
||||
k += 1
|
||||
extra_feature_channels = in_channels
|
||||
if sa_configs is not None:
|
||||
num_centers, radius, num_neighbors, out_channels = sa_configs
|
||||
_out_channels = []
|
||||
for oc in out_channels:
|
||||
if isinstance(oc, (list, tuple)):
|
||||
_out_channels.append([int(r * _oc) for _oc in oc])
|
||||
else:
|
||||
_out_channels.append(int(r * oc))
|
||||
out_channels = _out_channels
|
||||
if num_centers is None:
|
||||
block = PointNetAModule
|
||||
else:
|
||||
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
|
||||
num_neighbors=num_neighbors)
|
||||
sa_blocks.append(block(cfg=cfg,
|
||||
in_channels=extra_feature_channels+(embed_dim*has_temb if k==0 else 0 ),
|
||||
out_channels=out_channels,
|
||||
include_coordinates=True))
|
||||
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
|
||||
c += 1
|
||||
|
||||
if len(sa_blocks) == 1:
|
||||
sa_layers.append(sa_blocks[0])
|
||||
else:
|
||||
sa_layers.append(nn.Sequential(*sa_blocks))
|
||||
|
||||
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
|
||||
|
||||
|
||||
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
|
||||
dropout=0.1, has_temb=1,
|
||||
with_se=False, normalize=True, eps=0,
|
||||
width_multiplier=1, voxel_resolution_multiplier=1,
|
||||
verbose=True, cfg={}):
|
||||
assert(len(cfg) > 0), cfg
|
||||
r, vr = width_multiplier, voxel_resolution_multiplier
|
||||
|
||||
fp_layers = []
|
||||
c = 0
|
||||
|
||||
for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
|
||||
fp_blocks = []
|
||||
out_channels = tuple(int(r * oc) for oc in fp_configs)
|
||||
fp_blocks.append(
|
||||
PointNetFPModule(
|
||||
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim*has_temb,
|
||||
out_channels=out_channels,
|
||||
cfg=cfg)
|
||||
)
|
||||
in_channels = out_channels[-1]
|
||||
|
||||
if conv_configs is not None:
|
||||
out_channels, num_blocks, voxel_resolution = conv_configs
|
||||
out_channels = int(r * out_channels)
|
||||
for p in range(num_blocks):
|
||||
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
|
||||
if voxel_resolution is None:
|
||||
block = functools.partial(SharedMLP, cfg=cfg)
|
||||
else:
|
||||
block = functools.partial(PVConv, kernel_size=3,
|
||||
resolution=int(vr * voxel_resolution), attention=attention,
|
||||
dropout=dropout,
|
||||
with_se=with_se, # with_se_relu=True,
|
||||
normalize=normalize, eps=eps,
|
||||
verbose=verbose,
|
||||
cfg=cfg)
|
||||
|
||||
fp_blocks.append(block(in_channels, out_channels))
|
||||
in_channels = out_channels
|
||||
if len(fp_blocks) == 1:
|
||||
fp_layers.append(fp_blocks[0])
|
||||
else:
|
||||
fp_layers.append(nn.Sequential(*fp_blocks))
|
||||
|
||||
c += 1
|
||||
|
||||
return fp_layers, in_channels
|
||||
|
230
models/score_sde/resnet.py
Normal file
230
models/score_sde/resnet.py
Normal file
|
@ -0,0 +1,230 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
""" implement the gloabl prior for LION
|
||||
"""
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
import functools
|
||||
import torch
|
||||
from ..utils import init_temb_fun, mask_inactive_variables
|
||||
|
||||
class SE(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(channel, channel // reduction, 1, 1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(channel // reduction, channel, 1, 1, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
return inputs * self.fc(inputs)
|
||||
|
||||
class ResBlockSEClip(nn.Module):
|
||||
"""
|
||||
fixed the conv0 not used error in ResBlockSE
|
||||
"""
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self.non_linearity = nn.ReLU(inplace=True)
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(input_dim*2, output_dim, 1, 1)
|
||||
self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
|
||||
in_ch = self.output_dim
|
||||
self.SE = SE(in_ch)
|
||||
def forward(self, x, t):
|
||||
## logger.info('x: {}, t: {}, input_dim={}', x.shape, t.shape, self.input_dim)
|
||||
clip_feat = t[:, self.input_dim:].contiguous()
|
||||
t = t[:,:self.input_dim].contiguous()
|
||||
output = x + t
|
||||
output = torch.cat([output, clip_feat], dim=1).contiguous()
|
||||
output = self.conv1(output)
|
||||
output = self.non_linearity(output)
|
||||
output = self.conv2(output)
|
||||
output = self.non_linearity(output)
|
||||
output = self.SE(output)
|
||||
shortcut = x
|
||||
return shortcut + output
|
||||
def __repr__(self):
|
||||
return "ResBlockSEClip(%d, %d)"%(self.input_dim, self.output_dim)
|
||||
|
||||
|
||||
|
||||
class ResBlockSEDrop(nn.Module):
|
||||
"""
|
||||
fixed the conv0 not used error in ResBlockSE
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, dropout):
|
||||
super().__init__()
|
||||
self.non_linearity = nn.ReLU(inplace=True)
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
|
||||
self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
|
||||
in_ch = self.output_dim
|
||||
self.SE = SE(in_ch)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout_ratio = dropout
|
||||
|
||||
def forward(self, x, t):
|
||||
output = x + t
|
||||
output = self.conv1(output)
|
||||
output = self.non_linearity(output)
|
||||
output = self.dropout(output)
|
||||
output = self.conv2(output)
|
||||
output = self.non_linearity(output)
|
||||
output = self.SE(output)
|
||||
shortcut = x
|
||||
return shortcut + output
|
||||
|
||||
def __repr__(self):
|
||||
return "ResBlockSE_withdropout(%d, %d, drop=%f)" % (
|
||||
self.input_dim, self.output_dim, self.dropout_ratio)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
# resample=None, act=nn.ELU(),
|
||||
# normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
|
||||
super().__init__()
|
||||
self.non_linearity = nn.ELU()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
|
||||
self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
|
||||
in_ch = self.output_dim
|
||||
self.normalize1 = nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
||||
num_channels=in_ch, eps=1e-6)
|
||||
self.normalize2 = nn.GroupNorm(num_groups=min(in_ch // 4, 32),
|
||||
num_channels=in_ch, eps=1e-6)
|
||||
|
||||
def forward(self, x, t):
|
||||
x = x + t
|
||||
output = self.conv1(x)
|
||||
output = self.normalize1(output)
|
||||
output = self.non_linearity(output)
|
||||
output = self.conv2(output)
|
||||
output = self.normalize2(output)
|
||||
output = self.non_linearity(output)
|
||||
shortcut = x
|
||||
return shortcut + output
|
||||
|
||||
def __repr__(self):
|
||||
return "ResBlock(%d, %d)" % (self.input_dim, self.output_dim)
|
||||
|
||||
|
||||
class Prior(nn.Module):
|
||||
building_block = ResBlock
|
||||
|
||||
def __init__(self, args, num_input_channels, *oargs, **kwargs):
|
||||
super().__init__()
|
||||
# args: cfg.sde
|
||||
# oargs: other argument: the global argument
|
||||
self.condition_input = kwargs.get('condition_input', False)
|
||||
self.cfg = oargs[0]
|
||||
self.clip_forge_enable = self.cfg.clipforge.enable # kwargs.get('clipforge.enable', 0)
|
||||
|
||||
logger.info('[Build Resnet Prior] Has condition input: {}; clipforge {}; '
|
||||
'learn_mixing_logit={}, ', self.condition_input,
|
||||
self.clip_forge_enable, args.learn_mixing_logit)
|
||||
|
||||
self.act = act = nn.SiLU()
|
||||
self.num_scales = args.num_scales_dae
|
||||
self.num_input_channels = num_input_channels
|
||||
|
||||
self.nf = nf = args.num_channels_dae
|
||||
num_cell_per_scale_dae = args.num_cell_per_scale_dae if 'num_cell_per_scale_dae' not in kwargs else kwargs[
|
||||
'num_cell_per_scale_dae']
|
||||
|
||||
# take clip feature as input
|
||||
if self.clip_forge_enable:
|
||||
self.clip_feat_mapping = nn.Conv1d(self.cfg.clipforge.feat_dim, self.nf, 1)
|
||||
|
||||
# mixed_prediction #
|
||||
self.mixed_prediction = args.mixed_prediction # This enables mixed prediction
|
||||
if self.mixed_prediction:
|
||||
logger.info('init-mixing_logit = {}, after sigmoid = {}',
|
||||
args.mixing_logit_init, torch.sigmoid(torch.tensor(args.mixing_logit_init)))
|
||||
assert(args.mixing_logit_init), f'require learning'
|
||||
# if not args.learn_mixing_logit and args.hypara_mixing_logit:
|
||||
# # not learn, treat it as hyparameters
|
||||
# init = args.mixing_logit_init * torch.ones(size=[1, num_input_channels, 1, 1])
|
||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False) # not update
|
||||
# self.is_active = None
|
||||
# elif not args.learn_mixing_logit: # not learn, loaded from c04cd1h exp
|
||||
# init = torch.load('../exp/1110/chair/c04cd1h_hvae3s_390f8dhInitSepesTrainvae0_hvaeB72l1E4W1/mlogit.pt')
|
||||
# self.mixing_logit = torch.nn.Parameter(init, requires_grad=False)
|
||||
# self.is_active = None
|
||||
# else:
|
||||
if True:
|
||||
init = args.mixing_logit_init * torch.ones(size=[1, num_input_channels, 1, 1])
|
||||
self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
|
||||
self.is_active = None
|
||||
else: # no mixing_logit
|
||||
self.mixing_logit = None
|
||||
self.is_active = None
|
||||
|
||||
self.embedding_dim = args.embedding_dim
|
||||
self.embedding_dim_mult = 4
|
||||
self.temb_fun = init_temb_fun(args.embedding_type, args.embedding_scale, args.embedding_dim)
|
||||
logger.info('[temb_fun] embedding_type={}, embedding_scale={}, embedding_dim={}',
|
||||
args.embedding_type, args.embedding_scale, args.embedding_dim)
|
||||
# exit()
|
||||
modules = []
|
||||
modules.append(nn.Conv2d(self.embedding_dim, self.embedding_dim * 4, 1, 1))
|
||||
modules.append(nn.Conv2d(self.embedding_dim * 4, nf, 1, 1))
|
||||
self.temb_layer = nn.Sequential(*modules)
|
||||
|
||||
modules = []
|
||||
input_channels = num_input_channels
|
||||
self.input_layer = nn.Conv2d(input_channels, nf, 1, 1)
|
||||
in_ch = nf
|
||||
for i_block in range(args.num_cell_per_scale_dae):
|
||||
modules.append(self.building_block(nf, nf))
|
||||
self.output_layer = nn.Conv2d(nf, input_channels, 1, 1)
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
# time embedding
|
||||
if t.dim() == 0:
|
||||
t = t.expand(1)
|
||||
temb = self.temb_fun(t)[:, :, None, None] # make it 4d
|
||||
temb = self.temb_layer(temb)
|
||||
|
||||
if self.clip_forge_enable:
|
||||
clip_feat = kwargs['clip_feat']
|
||||
clip_feat = self.clip_feat_mapping(clip_feat[:, :, None])[:, :, :, None] # B,D -> BD1->B,D,1,1
|
||||
if temb.shape[0] == 1 and temb.shape[0] < clip_feat.shape[0]:
|
||||
temb = temb.expand(clip_feat.shape[0], -1, -1, -1)
|
||||
temb = torch.cat([temb, clip_feat], dim=1) # add to temb feature
|
||||
# mask out inactive variables
|
||||
if self.mixed_prediction and self.is_active is not None:
|
||||
x = mask_inactive_variables(x, self.is_active)
|
||||
x = self.input_layer(x)
|
||||
for layer in self.all_modules:
|
||||
enc_input = x
|
||||
x = layer(enc_input, temb)
|
||||
|
||||
h = self.output_layer(x)
|
||||
return h
|
||||
|
||||
|
||||
class PriorSEDrop(Prior):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.building_block = functools.partial(ResBlockSEDrop, dropout=args[0].dropout)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
class PriorSEClip(Prior):
|
||||
building_block = ResBlockSEClip
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
54
models/shapelatent_modules.py
Normal file
54
models/shapelatent_modules.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
from .pvcnn2 import create_pointnet2_sa_components
|
||||
# implement the global encoder for VAE model
|
||||
|
||||
class PointNetPlusEncoder(nn.Module):
|
||||
sa_blocks = [
|
||||
[[32, 2, 32], [1024, 0.1, 32, [32, 32]]],
|
||||
[[32, 1, 16], [256, 0.2, 32, [32, 64]]]
|
||||
]
|
||||
force_att = 0 # add attention to all layers
|
||||
def __init__(self, zdim, input_dim, extra_feature_channels=0, args={}):
|
||||
super().__init__()
|
||||
sa_blocks = self.sa_blocks
|
||||
layers, sa_in_channels, channels_sa_features, _ = \
|
||||
create_pointnet2_sa_components(sa_blocks,
|
||||
extra_feature_channels, input_dim=input_dim,
|
||||
embed_dim=0, force_att=self.force_att,
|
||||
use_att=True, with_se=True)
|
||||
self.mlp = nn.Linear(channels_sa_features, zdim*2)
|
||||
self.zdim = zdim
|
||||
logger.info('[Encoder] zdim={}, out_sigma={}; force_att: {}', zdim, True, self.force_att)
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.voxel_dim = [n[1][-1][-1] for n in self.sa_blocks]
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: B,N,3
|
||||
Returns:
|
||||
mu, sigma: B,D
|
||||
"""
|
||||
output = {}
|
||||
x = x.transpose(1, 2) # B,3,N
|
||||
xyz = x ## x[:,:3,:]
|
||||
features = x
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
features, xyz, _ = layer( (features, xyz, None) )
|
||||
# features: B,D,N; xyz: B,3,N
|
||||
|
||||
features = features.max(-1)[0]
|
||||
features = self.mlp(features)
|
||||
mu_1d, sigma_1d = features[:, :self.zdim], features[:, self.zdim:]
|
||||
output.update({'mu_1d': mu_1d, 'sigma_1d': sigma_1d})
|
||||
return output
|
||||
|
||||
|
52
models/utils.py
Normal file
52
models/utils.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
|
||||
def mask_inactive_variables(x, is_active):
|
||||
x = x * is_active
|
||||
return x
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim, scale):
|
||||
super(PositionalEmbedding, self).__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps):
|
||||
assert len(timesteps.shape) == 1
|
||||
timesteps = timesteps * self.scale
|
||||
half_dim = self.embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class RandomFourierEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim, scale):
|
||||
super(RandomFourierEmbedding, self).__init__()
|
||||
self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False)
|
||||
|
||||
def forward(self, timesteps):
|
||||
emb = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359)
|
||||
return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
|
||||
|
||||
def init_temb_fun(embedding_type, embedding_scale, embedding_dim):
|
||||
if embedding_type == 'positional':
|
||||
temb_fun = PositionalEmbedding(embedding_dim, embedding_scale)
|
||||
elif embedding_type == 'fourier':
|
||||
temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return temb_fun
|
339
models/vae_adain.py
Normal file
339
models/vae_adain.py
Normal file
|
@ -0,0 +1,339 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
import importlib
|
||||
import torch.nn as nn
|
||||
from .distributions import Normal
|
||||
from utils.model_helper import import_model
|
||||
from utils.model_helper import loss_fn
|
||||
from utils import utils as helper
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.num_total_iter = 0
|
||||
self.args = args
|
||||
self.input_dim = args.ddpm.input_dim
|
||||
latent_dim = args.shapelatent.latent_dim
|
||||
self.latent_dim = latent_dim
|
||||
self.kl_weight = args.shapelatent.kl_weight
|
||||
|
||||
self.num_points = args.data.tr_max_sample_points
|
||||
# ---- global ---- #
|
||||
# build encoder
|
||||
self.style_encoder = import_model(args.latent_pts.style_encoder)(
|
||||
zdim=args.latent_pts.style_dim,
|
||||
input_dim=self.input_dim,
|
||||
args=args)
|
||||
if len(args.latent_pts.style_mlp):
|
||||
self.style_mlp = import_model(args.latent_pts.style_mlp)(args)
|
||||
else:
|
||||
self.style_mlp = None
|
||||
|
||||
self.encoder = import_model(args.shapelatent.encoder_type)(
|
||||
zdim=latent_dim,
|
||||
input_dim=self.input_dim,
|
||||
args=args)
|
||||
|
||||
# build decoder
|
||||
self.decoder = import_model(args.shapelatent.decoder_type)(
|
||||
context_dim=latent_dim,
|
||||
point_dim=args.ddpm.input_dim,
|
||||
args=args)
|
||||
logger.info('[Build Model] style_encoder: {}, encoder: {}, decoder: {}',
|
||||
args.latent_pts.style_encoder,
|
||||
args.shapelatent.encoder_type,
|
||||
args.shapelatent.decoder_type)
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x, class_label=None):
|
||||
batch_size, _, point_dim = x.size()
|
||||
assert(x.shape[2] == self.input_dim), f'expect input in ' \
|
||||
f'[B,Npoint,PointDim={self.input_dim}], get: {x.shape}'
|
||||
x_0_target = x
|
||||
latent_list = []
|
||||
all_eps = []
|
||||
all_log_q = []
|
||||
if self.args.data.cond_on_cat:
|
||||
assert(class_label is not None), f'require class label input for cond on cat'
|
||||
cls_emb = self.class_embedding(class_label)
|
||||
enc_input = x, cls_emb
|
||||
else:
|
||||
enc_input = x
|
||||
|
||||
# ---- global style encoder ---- #
|
||||
z = self.style_encoder(enc_input)
|
||||
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
|
||||
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
|
||||
z_global = dist.sample()[0]
|
||||
all_eps.append(z_global)
|
||||
all_log_q.append(dist.log_p(z_global))
|
||||
latent_list.append( [z_global, z_mu, z_sigma] )
|
||||
|
||||
# ---- original encoder ---- #
|
||||
style = z_global # torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
|
||||
style = self.style_mlp(style) if self.style_mlp is not None else style
|
||||
z = self.encoder([x, style])
|
||||
z_mu, z_sigma = z['mu_1d'], z['sigma_1d']
|
||||
z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
|
||||
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
|
||||
z_local = dist.sample()[0]
|
||||
all_eps.append(z_local)
|
||||
all_log_q.append(dist.log_p(z_local))
|
||||
latent_list.append( [z_local, z_mu, z_sigma] )
|
||||
all_eps = self.compose_eps(all_eps)
|
||||
if self.args.data.cond_on_cat:
|
||||
return all_eps, all_log_q, latent_list, cls_emb
|
||||
else:
|
||||
return all_eps, all_log_q, latent_list
|
||||
|
||||
def compose_eps(self, all_eps):
|
||||
return torch.cat(all_eps, dim=1) # style: [B,D1], latent pts: [B,ND2]
|
||||
|
||||
def decompose_eps(self, all_eps):
|
||||
eps_style = all_eps[:,:self.args.latent_pts.style_dim]
|
||||
eps_local = all_eps[:,self.args.latent_pts.style_dim:]
|
||||
return [eps_style, eps_local]
|
||||
|
||||
def encode_global(self, x, class_label=None):
|
||||
|
||||
batch_size, N, point_dim = x.size()
|
||||
if self.args.data.cond_on_cat:
|
||||
assert(class_label is not None), f'require class label input for cond on cat'
|
||||
cls_emb = self.class_embedding(class_label)
|
||||
enc_input = x, cls_emb
|
||||
else:
|
||||
enc_input = x
|
||||
|
||||
z = self.style_encoder(enc_input)
|
||||
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
|
||||
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
|
||||
return dist
|
||||
|
||||
def global2style(self, style): ##, cls_emb=None):
|
||||
Ndim = len(style.shape)
|
||||
if Ndim == 4:
|
||||
style = style.squeeze(-1).squeeze(-1)
|
||||
style = self.style_mlp(style) if self.style_mlp is not None else style
|
||||
if Ndim == 4:
|
||||
style = style.unsqueeze(-1).unsqueeze(-1)
|
||||
return style
|
||||
|
||||
def encode_local(self, x, style):
|
||||
# ---- original encoder ---- #
|
||||
z = self.encoder([x, style])
|
||||
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
|
||||
z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
|
||||
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
|
||||
return dist
|
||||
|
||||
def recont(self, x, target=None, class_label=None, cls_emb=None):
|
||||
batch_size, N, point_dim = x.size()
|
||||
assert(x.shape[2] == self.input_dim), f'expect input in ' \
|
||||
f'[B,Npoint,PointDim={self.input_dim}], get: {x.shape}'
|
||||
x_0_target = x if target is None else target
|
||||
latent_list = []
|
||||
all_eps = []
|
||||
all_log_q = []
|
||||
|
||||
# ---- global style encoder ---- #
|
||||
if self.args.data.cond_on_cat:
|
||||
if class_label is not None:
|
||||
assert(class_label is not None)
|
||||
cls_emb = self.class_embedding(class_label)
|
||||
else:
|
||||
assert(cls_emb is not None)
|
||||
|
||||
enc_input = x, cls_emb
|
||||
else:
|
||||
enc_input = x
|
||||
z = self.style_encoder(enc_input)
|
||||
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
|
||||
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
|
||||
|
||||
z_global = dist.sample()[0]
|
||||
all_eps.append(z_global)
|
||||
all_log_q.append(dist.log_p(z_global))
|
||||
latent_list.append( [z_global, z_mu, z_sigma] )
|
||||
|
||||
# ---- original encoder ---- #
|
||||
style = torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
|
||||
style = self.style_mlp(style) if self.style_mlp is not None else style
|
||||
z = self.encoder([x, style])
|
||||
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
|
||||
z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
|
||||
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
|
||||
z_local = dist.sample()[0]
|
||||
all_eps.append(z_local)
|
||||
all_log_q.append(dist.log_p(z_local))
|
||||
latent_list.append( [z_local, z_mu, z_sigma] )
|
||||
|
||||
# ---- decoder ---- #
|
||||
x_0_pred = self.decoder(None, beta=None, context=z_local, style=style) # (B,ncenter,3)
|
||||
|
||||
make_4d = lambda x: x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2 else x.unsqueeze(-1)
|
||||
all_eps = [make_4d(e) for e in all_eps]
|
||||
all_log_q = [make_4d(e) for e in all_log_q]
|
||||
|
||||
output = {
|
||||
'all_eps': all_eps,
|
||||
'all_log_q': all_log_q,
|
||||
'latent_list': latent_list,
|
||||
'x_0_pred':x_0_pred,
|
||||
'x_0_target': x_0_target,
|
||||
'x_t': torch.zeros_like(x_0_target),
|
||||
't': torch.zeros(batch_size),
|
||||
'x_0': x_0_target
|
||||
}
|
||||
output['hist/global_var'] = latent_list[0][2].exp()
|
||||
|
||||
if 'LatentPoint' in self.args.shapelatent.decoder_type:
|
||||
latent_shape = [batch_size, -1, self.latent_dim + self.input_dim]
|
||||
if 'Hir' in self.args.shapelatent.decoder_type:
|
||||
latent_pts = z_local[:,:-self.args.latent_pts.latent_dim_ext[0]].view(*latent_shape)[:,:,:3].contiguous().clone()
|
||||
else:
|
||||
latent_pts = z_local.view(*latent_shape)[:,:,:self.input_dim].contiguous().clone()
|
||||
|
||||
output['vis/latent_pts'] = latent_pts.detach().cpu().view(batch_size,
|
||||
-1, self.input_dim) # B,N,3
|
||||
output['final_pred'] = output['x_0_pred']
|
||||
return output
|
||||
|
||||
def get_loss(self, x, writer=None, it=None, ## weight_loss_1=1,
|
||||
noisy_input=None, class_label=None, **kwargs):
|
||||
"""
|
||||
shapelatent z ~ q(z|x_0)
|
||||
and x_t ~ q(x_t|x_0, t), t ~ Uniform(T)
|
||||
forward and get x_{t-1} ~ p(x_{t-1} | x_t, z)
|
||||
Args:
|
||||
x: Input point clouds, (B, N, d).
|
||||
"""
|
||||
## kl_weight = self.kl_weight
|
||||
if self.args.trainer.anneal_kl and self.num_total_iter > 0:
|
||||
global_step = it
|
||||
kl_weight = helper.kl_coeff(step=global_step,
|
||||
total_step=self.args.sde.kl_anneal_portion_vada * self.num_total_iter,
|
||||
constant_step=self.args.sde.kl_const_portion_vada * self.num_total_iter,
|
||||
min_kl_coeff=self.args.sde.kl_const_coeff_vada,
|
||||
max_kl_coeff=self.args.sde.kl_max_coeff_vada)
|
||||
else:
|
||||
kl_weight = self.kl_weight
|
||||
|
||||
batch_size = x.shape[0]
|
||||
# CHECKDIM(x, 2, self.input_dim)
|
||||
assert(x.shape[2] == self.input_dim)
|
||||
|
||||
inputs = noisy_input if noisy_input is not None else x
|
||||
output = self.recont(inputs, target=x, class_label=class_label)
|
||||
|
||||
x_0_pred, x_0_target = output['x_0_pred'], output['x_0_target']
|
||||
loss_0 = loss_fn(x_0_pred, x_0_target, self.args.ddpm.loss_type,
|
||||
self.input_dim, batch_size).mean()
|
||||
rec_loss = loss_0
|
||||
output['print/loss_0'] = loss_0
|
||||
output['rec_loss'] = rec_loss
|
||||
|
||||
# Loss
|
||||
## z_global, z_sigma, z_mu = output['z_global'], output['z_sigma'], output['z_mu']
|
||||
kl_term_list = []
|
||||
weighted_kl_terms = []
|
||||
for pairs_id, pairs in enumerate(output['latent_list']):
|
||||
cz, cmu, csigma = pairs
|
||||
log_sigma = csigma
|
||||
kl_term_close = (0.5*log_sigma.exp()**2 +
|
||||
0.5*cmu**2 - log_sigma - 0.5).view(
|
||||
batch_size, -1)
|
||||
if 'LatentPoint' in self.args.shapelatent.decoder_type and 'Hir' not in self.args.shapelatent.decoder_type:
|
||||
if pairs_id == 1:
|
||||
latent_shape = [batch_size, -1, self.latent_dim + self.input_dim]
|
||||
kl_pt = kl_term_close.view(*latent_shape)[:,:,:self.input_dim]
|
||||
kl_feat = kl_term_close.view(*latent_shape)[:,:,self.input_dim:]
|
||||
weighted_kl_terms.append(kl_pt.sum(2).sum(1) * self.args.latent_pts.weight_kl_pt)
|
||||
weighted_kl_terms.append(kl_feat.sum(2).sum(1) * self.args.latent_pts.weight_kl_feat)
|
||||
|
||||
output['print/kl_pt%d'%pairs_id] = kl_pt.sum(2).sum(1)
|
||||
output['print/kl_feat%d'%pairs_id] = kl_feat.sum(2).sum(1)
|
||||
|
||||
output['print/z_var_pt%d'%pairs_id] = (log_sigma.view(*latent_shape)[:,:,:self.input_dim]
|
||||
).exp()**2
|
||||
output['print/z_var_feat%d'%pairs_id] = (log_sigma.view(*latent_shape)[:,:,self.input_dim:]
|
||||
).exp()**2
|
||||
output['print/z_mean_feat%d'%pairs_id] = cmu.view(*latent_shape)[:,:,self.input_dim:].mean()
|
||||
elif pairs_id == 0:
|
||||
kl_style = kl_term_close
|
||||
weighted_kl_terms.append(kl_style.sum(-1) * self.args.latent_pts.weight_kl_glb)
|
||||
|
||||
output['print/kl_glb%d'%pairs_id] = kl_style.sum(-1)
|
||||
output['print/z_var_glb%d'%pairs_id] = (log_sigma).exp()**2
|
||||
|
||||
kl_term_close = kl_term_close.sum(-1)
|
||||
kl_term_list.append(kl_term_close)
|
||||
output['print/kl_%d'%pairs_id] = kl_term_close
|
||||
output['print/z_mean_%d'%pairs_id] = cmu.mean()
|
||||
output['print/z_mag_%d'%pairs_id] = cmu.abs().max()
|
||||
# logger.info('log_sigma: {}, mean: {}', log_sigma.shape, (log_sigma.exp()**2).mean())
|
||||
output['print/z_var_%d'%pairs_id] = (log_sigma).exp()**2
|
||||
output['print/z_logsigma_%d'%pairs_id] = log_sigma
|
||||
output['print/kl_weight'] = kl_weight
|
||||
|
||||
|
||||
loss_recons = rec_loss
|
||||
if len(weighted_kl_terms) > 0:
|
||||
kl = kl_weight * sum(weighted_kl_terms)
|
||||
else:
|
||||
kl = kl_weight * sum(kl_term_list)
|
||||
loss = kl + loss_recons * self.args.weight_recont
|
||||
output['msg/kl'] = kl
|
||||
output['msg/rec'] = loss_recons
|
||||
output['loss'] = loss
|
||||
return output
|
||||
|
||||
def pz(self, w):
|
||||
return w
|
||||
|
||||
def sample(self, num_samples=10, temp=None, decomposed_eps=[],
|
||||
enable_autocast=False, device_str='cuda', cls_emb=None):
|
||||
""" currently not support the samples of local level
|
||||
Return:
|
||||
model_output: [B,N,D]
|
||||
"""
|
||||
batch_size = num_samples
|
||||
center_emd = None
|
||||
if 'LatentPoint' in self.args.shapelatent.decoder_type:
|
||||
# Latent Point Model: latent shape; B; ND
|
||||
latent_shape = (num_samples, self.num_points*(self.latent_dim+self.input_dim))
|
||||
style_latent_shape = (num_samples, self.args.latent_pts.style_dim)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if len(decomposed_eps) == 0:
|
||||
z_local = torch.zeros(*latent_shape).to(
|
||||
torch.device(device_str)).normal_()
|
||||
z_global = torch.zeros(*style_latent_shape).to(
|
||||
torch.device(device_str)).normal_()
|
||||
else:
|
||||
z_global = decomposed_eps[0]
|
||||
z_local = decomposed_eps[1]
|
||||
|
||||
z_local = z_local.view(*latent_shape)
|
||||
z_global = z_global.view(style_latent_shape)
|
||||
|
||||
style = z_global
|
||||
style = self.style_mlp(style) if self.style_mlp is not None else style
|
||||
x_0_pred = self.decoder(None, beta=None,
|
||||
context=z_local, style=z_global) # (B,ncenter,3)
|
||||
## CHECKSIZE(x_0_pred, (batch_size,self.num_points,[3,6]))
|
||||
return x_0_pred
|
||||
|
||||
def latent_shape(self):
|
||||
return [
|
||||
[self.args.latent_pts.style_dim, 1, 1],
|
||||
[self.num_points*(self.latent_dim+self.input_dim),1,1]
|
||||
]
|
43
script/compute_score.py
Normal file
43
script/compute_score.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||||
import sys
|
||||
sys.path.append('.')
|
||||
from utils.eval_helper import compute_score
|
||||
# samples = sys.argv[1]
|
||||
# ref = sys.argv[2]
|
||||
|
||||
samples = './lion_ckpt/unconditional/car/samples.pt'
|
||||
ref = './datasets/test_data/ref_val_car.pt'
|
||||
compute_score(samples, ref_name=ref)
|
||||
"""
|
||||
will get:
|
||||
[Test] MinMatDis | CD 0.000913 | EMD 0.007523
|
||||
[Test] Coverage | CD 0.500000 | EMD 0.565341
|
||||
[Test] 1NN-Accur | CD 0.534091 | EMD 0.511364
|
||||
[Test] JsnShnDis | 0.009229
|
||||
"""
|
||||
|
||||
samples = './lion_ckpt/unconditional/chair/samples.pt'
|
||||
ref = './datasets/test_data/ref_val_chair.pt'
|
||||
compute_score(samples, ref_name=ref)
|
||||
"""
|
||||
[Test] MinMatDis | CD 0.002643 | EMD 0.015516
|
||||
[Test] Coverage | CD 0.489426 | EMD 0.521148
|
||||
[Test] 1NN-Accur | CD 0.537009 | EMD 0.523414
|
||||
[Test] JsnShnDis | 0.013535
|
||||
"""
|
||||
|
||||
samples = './lion_ckpt/unconditional/chair/samples.pt'
|
||||
ref = './datasets/test_data/ref_val_chair.pt'
|
||||
compute_score(samples, ref_name=ref)
|
||||
"""
|
||||
[Test] MinMatDis | CD 0.000221 | EMD 0.003706
|
||||
[Test] Coverage | CD 0.471605 | EMD 0.496296
|
||||
[Test] 1NN-Accur | CD 0.674074 | EMD 0.612346
|
||||
[Test] JsnShnDis | 0.060703
|
||||
"""
|
3
third_party/ChamferDistancePytorch/.gitignore
vendored
Normal file
3
third_party/ChamferDistancePytorch/.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
*__pycache__*
|
||||
/tmp
|
||||
tmp/*
|
21
third_party/ChamferDistancePytorch/LICENSE
vendored
Normal file
21
third_party/ChamferDistancePytorch/LICENSE
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2019 ThibaultGROUEIX
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
104
third_party/ChamferDistancePytorch/README.md
vendored
Executable file
104
third_party/ChamferDistancePytorch/README.md
vendored
Executable file
|
@ -0,0 +1,104 @@
|
|||
* adapted from https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
|
||||
|
||||
----------------------------------
|
||||
# Pytorch Chamfer Distance.
|
||||
|
||||
Include a **CUDA** version, and a **PYTHON** version with pytorch standard operations.
|
||||
NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt thresholds accordingly.
|
||||
|
||||
- [x] F - Score
|
||||
|
||||
|
||||
|
||||
### CUDA VERSION
|
||||
|
||||
- [x] JIT compilation
|
||||
- [x] Supports multi-gpu
|
||||
- [x] 2D point clouds.
|
||||
- [x] 3D point clouds.
|
||||
- [x] 5D point clouds.
|
||||
- [x] Contiguous() safe.
|
||||
|
||||
|
||||
|
||||
### Python Version
|
||||
|
||||
- [x] Supports any dimension
|
||||
|
||||
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
import torch, chamfer3D.dist_chamfer_3D, fscore
|
||||
chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
|
||||
points1 = torch.rand(32, 1000, 3).cuda()
|
||||
points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda()
|
||||
dist1, dist2, idx1, idx2 = chamLoss(points1, points2)
|
||||
f_score, precision, recall = fscore.fscore(dist1, dist2)
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Add it to your project as a submodule
|
||||
|
||||
```shell
|
||||
git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Benchmark: [forward + backward] pass
|
||||
- [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4
|
||||
- [x] p1 : 32 x 2000 x dim
|
||||
- [x] p2 : 32 x 1000 x dim
|
||||
|
||||
| *Timing (sec * 1000)* | 2D | 3D | 5D |
|
||||
| ---------- | -------- | ------- | ------- |
|
||||
| **Cuda Compiled** | **1.2** | 1.4 |1.8 |
|
||||
| **Cuda JIT** | 1.3 | **1.4** |**1.5** |
|
||||
| **Python** | 37 | 37 | 37 |
|
||||
|
||||
|
||||
| *Memory (MB)* | 2D | 3D | 5D |
|
||||
| ---------- | -------- | ------- | ------- |
|
||||
| **Cuda Compiled** | 529 | 529 | 549 |
|
||||
| **Cuda JIT** | **520** | **529** |**549** |
|
||||
| **Python** | 2495 | 2495 | 2495 |
|
||||
|
||||
|
||||
|
||||
### What is the chamfer distance ?
|
||||
|
||||
[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning
|
||||
|
||||
|
||||
|
||||
### Aknowledgment
|
||||
|
||||
Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu).
|
||||
|
||||
JIT cool trick from [Christian Diller](https://github.com/chrdiller)
|
||||
|
||||
### Troubleshoot
|
||||
|
||||
- `Undefined symbol: Zxxxxxxxxxxxxxxxxx `:
|
||||
|
||||
--> Fix: Make sure to `import torch` before you `import chamfer`.
|
||||
--> Use pytorch.version >= 1.1.0
|
||||
|
||||
- [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167)
|
||||
|
||||
```shell
|
||||
wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
|
||||
sudo unzip ninja-linux.zip -d /usr/local/bin/
|
||||
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#### TODO:
|
||||
|
||||
* Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions
|
182
third_party/ChamferDistancePytorch/chamfer2D/chamfer2D.cu
vendored
Executable file
182
third_party/ChamferDistancePytorch/chamfer2D/chamfer2D.cu
vendored
Executable file
|
@ -0,0 +1,182 @@
|
|||
|
||||
#include <stdio.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
||||
|
||||
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
|
||||
const int batch=512;
|
||||
__shared__ float buf[batch*2];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int k2=0;k2<m;k2+=batch){
|
||||
int end_k=min(m,k2+batch)-k2;
|
||||
for (int j=threadIdx.x;j<end_k*2;j+=blockDim.x){
|
||||
buf[j]=xyz2[(i*m+k2)*2+j];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz[(i*n+j)*2+0];
|
||||
float y1=xyz[(i*n+j)*2+1];
|
||||
int best_i=0;
|
||||
float best=0;
|
||||
int end_ka=end_k-(end_k&2);
|
||||
if (end_ka==batch){
|
||||
for (int k=0;k<batch;k+=4){
|
||||
{
|
||||
float x2=buf[k*2+0]-x1;
|
||||
float y2=buf[k*2+1]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*2+2]-x1;
|
||||
float y2=buf[k*2+3]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*2+4]-x1;
|
||||
float y2=buf[k*2+5]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*2+6]-x1;
|
||||
float y2=buf[k*2+7]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
for (int k=0;k<end_ka;k+=4){
|
||||
{
|
||||
float x2=buf[k*2+0]-x1;
|
||||
float y2=buf[k*2+1]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*2+2]-x1;
|
||||
float y2=buf[k*2+3]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*2+4]-x1;
|
||||
float y2=buf[k*2+5]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*2+6]-x1;
|
||||
float y2=buf[k*2+7]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int k=end_ka;k<end_k;k++){
|
||||
float x2=buf[k*2+0]-x1;
|
||||
float y2=buf[k*2+1]-y1;
|
||||
float d=x2*x2+y2*y2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
if (k2==0 || result[(i*n+j)]>best){
|
||||
result[(i*n+j)]=best;
|
||||
result_i[(i*n+j)]=best_i;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
|
||||
}
|
||||
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz1[(i*n+j)*2+0];
|
||||
float y1=xyz1[(i*n+j)*2+1];
|
||||
int j2=idx1[i*n+j];
|
||||
float x2=xyz2[(i*m+j2)*2+0];
|
||||
float y2=xyz2[(i*m+j2)*2+1];
|
||||
float g=grad_dist1[i*n+j]*2;
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*2+0]),g*(x1-x2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*2+1]),g*(y1-y2));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*2+0]),-(g*(x1-x2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*2+1]),-(g*(y1-y2)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
|
||||
// cudaMemset(grad_xyz1,0,b*n*3*4);
|
||||
// cudaMemset(grad_xyz2,0,b*m*3*4);
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
}
|
||||
|
33
third_party/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp
vendored
Executable file
33
third_party/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp
vendored
Executable file
|
@ -0,0 +1,33 @@
|
|||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
///TMP
|
||||
//#include "common.h"
|
||||
/// NOT TMP
|
||||
|
||||
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
|
||||
|
||||
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
|
||||
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
|
||||
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
|
||||
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
|
||||
}
|
80
third_party/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py
vendored
Normal file
80
third_party/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py
vendored
Normal file
|
@ -0,0 +1,80 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
chamfer_found = importlib.find_loader("chamfer_2D") is not None
|
||||
if not chamfer_found:
|
||||
## Cool trick from https://github.com/chrdiller
|
||||
print("Jitting Chamfer 2D")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
build_path = cur_path.replace('chamfer2D', 'tmp')
|
||||
os.makedirs(build_path, exist_ok=True)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_2D = load(name="chamfer_2D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
|
||||
], build_directory=build_path)
|
||||
print("Loaded JIT 2D CUDA chamfer distance")
|
||||
|
||||
else:
|
||||
import chamfer_2D
|
||||
print("Loaded compiled 2D CUDA chamfer distance")
|
||||
|
||||
# Chamfer's distance module @thibaultgroueix
|
||||
# GPU tensors only
|
||||
class chamfer_2DFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
batchsize, n, dim = xyz1.size()
|
||||
assert dim==2, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
_, m, dim = xyz2.size()
|
||||
assert dim==2, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
device = xyz1.device
|
||||
|
||||
device = xyz1.device
|
||||
|
||||
dist1 = torch.zeros(batchsize, n)
|
||||
dist2 = torch.zeros(batchsize, m)
|
||||
|
||||
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
|
||||
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
|
||||
|
||||
dist1 = dist1.to(device)
|
||||
dist2 = dist2.to(device)
|
||||
idx1 = idx1.to(device)
|
||||
idx2 = idx2.to(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
|
||||
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
|
||||
return dist1, dist2, idx1, idx2
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
|
||||
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
|
||||
graddist1 = graddist1.contiguous()
|
||||
graddist2 = graddist2.contiguous()
|
||||
device = graddist1.device
|
||||
|
||||
gradxyz1 = torch.zeros(xyz1.size())
|
||||
gradxyz2 = torch.zeros(xyz2.size())
|
||||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_2D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
class chamfer_2DDist(nn.Module):
|
||||
def __init__(self):
|
||||
super(chamfer_2DDist, self).__init__()
|
||||
|
||||
def forward(self, input1, input2):
|
||||
input1 = input1.contiguous()
|
||||
input2 = input2.contiguous()
|
||||
return chamfer_2DFunction.apply(input1, input2)
|
14
third_party/ChamferDistancePytorch/chamfer2D/setup.py
vendored
Executable file
14
third_party/ChamferDistancePytorch/chamfer2D/setup.py
vendored
Executable file
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_2D',
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_2D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
|
||||
]),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
196
third_party/ChamferDistancePytorch/chamfer3D/chamfer3D.cu
vendored
Executable file
196
third_party/ChamferDistancePytorch/chamfer3D/chamfer3D.cu
vendored
Executable file
|
@ -0,0 +1,196 @@
|
|||
|
||||
#include <stdio.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
||||
|
||||
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
|
||||
const int batch=512;
|
||||
__shared__ float buf[batch*3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int k2=0;k2<m;k2+=batch){
|
||||
int end_k=min(m,k2+batch)-k2;
|
||||
for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
|
||||
buf[j]=xyz2[(i*m+k2)*3+j];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz[(i*n+j)*3+0];
|
||||
float y1=xyz[(i*n+j)*3+1];
|
||||
float z1=xyz[(i*n+j)*3+2];
|
||||
int best_i=0;
|
||||
float best=0;
|
||||
int end_ka=end_k-(end_k&3);
|
||||
if (end_ka==batch){
|
||||
for (int k=0;k<batch;k+=4){
|
||||
{
|
||||
float x2=buf[k*3+0]-x1;
|
||||
float y2=buf[k*3+1]-y1;
|
||||
float z2=buf[k*3+2]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*3+3]-x1;
|
||||
float y2=buf[k*3+4]-y1;
|
||||
float z2=buf[k*3+5]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*3+6]-x1;
|
||||
float y2=buf[k*3+7]-y1;
|
||||
float z2=buf[k*3+8]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*3+9]-x1;
|
||||
float y2=buf[k*3+10]-y1;
|
||||
float z2=buf[k*3+11]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
for (int k=0;k<end_ka;k+=4){
|
||||
{
|
||||
float x2=buf[k*3+0]-x1;
|
||||
float y2=buf[k*3+1]-y1;
|
||||
float z2=buf[k*3+2]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*3+3]-x1;
|
||||
float y2=buf[k*3+4]-y1;
|
||||
float z2=buf[k*3+5]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*3+6]-x1;
|
||||
float y2=buf[k*3+7]-y1;
|
||||
float z2=buf[k*3+8]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*3+9]-x1;
|
||||
float y2=buf[k*3+10]-y1;
|
||||
float z2=buf[k*3+11]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int k=end_ka;k<end_k;k++){
|
||||
float x2=buf[k*3+0]-x1;
|
||||
float y2=buf[k*3+1]-y1;
|
||||
float z2=buf[k*3+2]-z1;
|
||||
float d=x2*x2+y2*y2+z2*z2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
if (k2==0 || result[(i*n+j)]>best){
|
||||
result[(i*n+j)]=best;
|
||||
result_i[(i*n+j)]=best_i;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
|
||||
}
|
||||
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz1[(i*n+j)*3+0];
|
||||
float y1=xyz1[(i*n+j)*3+1];
|
||||
float z1=xyz1[(i*n+j)*3+2];
|
||||
int j2=idx1[i*n+j];
|
||||
float x2=xyz2[(i*m+j2)*3+0];
|
||||
float y2=xyz2[(i*m+j2)*3+1];
|
||||
float z2=xyz2[(i*m+j2)*3+2];
|
||||
float g=grad_dist1[i*n+j]*2;
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
|
||||
// cudaMemset(grad_xyz1,0,b*n*3*4);
|
||||
// cudaMemset(grad_xyz2,0,b*m*3*4);
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
}
|
||||
|
33
third_party/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp
vendored
Executable file
33
third_party/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp
vendored
Executable file
|
@ -0,0 +1,33 @@
|
|||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
///TMP
|
||||
//#include "common.h"
|
||||
/// NOT TMP
|
||||
|
||||
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
|
||||
|
||||
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
|
||||
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
|
||||
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
|
||||
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
|
||||
}
|
133
third_party/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py
vendored
Normal file
133
third_party/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py
vendored
Normal file
|
@ -0,0 +1,133 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
build_path = cur_path.replace('chamfer3D', 'tmp')
|
||||
os.makedirs(build_path, exist_ok=True)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_3D = load(name="chamfer_3D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
|
||||
], build_directory=build_path)
|
||||
|
||||
#chamfer_found = importlib.find_loader("chamfer_3D") is not None
|
||||
#if not chamfer_found:
|
||||
# ## Cool trick from https://github.com/chrdiller
|
||||
# print("Jitting Chamfer 3D")
|
||||
# cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
# build_path = cur_path.replace('chamfer3D', 'tmp')
|
||||
# os.makedirs(build_path, exist_ok=True)
|
||||
#
|
||||
# from torch.utils.cpp_extension import load
|
||||
# chamfer_3D = load(name="chamfer_3D",
|
||||
# sources=[
|
||||
# "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
# "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
|
||||
# ], build_directory=build_path)
|
||||
# print("Loaded JIT 3D CUDA chamfer distance")
|
||||
#
|
||||
#else:
|
||||
# import chamfer_3D
|
||||
# print("Loaded compiled 3D CUDA chamfer distance")
|
||||
|
||||
|
||||
# Chamfer's distance module @thibaultgroueix
|
||||
# GPU tensors only
|
||||
class chamfer_3DFunction(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
batchsize, n, dim = xyz1.size()
|
||||
assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
_, m, dim = xyz2.size()
|
||||
assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
device = xyz1.device
|
||||
|
||||
device = xyz1.device
|
||||
|
||||
dist1 = torch.zeros(batchsize, n)
|
||||
dist2 = torch.zeros(batchsize, m)
|
||||
|
||||
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
|
||||
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
|
||||
|
||||
dist1 = dist1.to(device)
|
||||
dist2 = dist2.to(device)
|
||||
idx1 = idx1.to(device)
|
||||
idx2 = idx2.to(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
|
||||
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
|
||||
return dist1, dist2, idx1, idx2
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
|
||||
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
|
||||
graddist1 = graddist1.contiguous()
|
||||
graddist2 = graddist2.contiguous()
|
||||
device = graddist1.device
|
||||
|
||||
gradxyz1 = torch.zeros(xyz1.size())
|
||||
gradxyz2 = torch.zeros(xyz2.size())
|
||||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_3D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
class chamfer_3DDist(nn.Module):
|
||||
def __init__(self):
|
||||
super(chamfer_3DDist, self).__init__()
|
||||
|
||||
def forward(self, input1, input2):
|
||||
input1 = input1.contiguous()
|
||||
input2 = input2.contiguous()
|
||||
return chamfer_3DFunction.apply(input1, input2)
|
||||
|
||||
|
||||
# Chamfer's distance module @thibaultgroueix
|
||||
# GPU tensors only
|
||||
class chamfer_3DFunction_noGrad(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
batchsize, n, dim = xyz1.size()
|
||||
assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
_, m, dim = xyz2.size()
|
||||
assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
device = xyz1.device
|
||||
|
||||
device = xyz1.device
|
||||
|
||||
dist1 = torch.zeros(batchsize, n)
|
||||
dist2 = torch.zeros(batchsize, m)
|
||||
|
||||
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
|
||||
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
|
||||
|
||||
dist1 = dist1.to(device)
|
||||
dist2 = dist2.to(device)
|
||||
idx1 = idx1.to(device)
|
||||
idx2 = idx2.to(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
|
||||
return dist1, dist2, idx1, idx2
|
||||
|
||||
class chamfer_3DDist_nograd(nn.Module):
|
||||
def __init__(self):
|
||||
super(chamfer_3DDist_nograd, self).__init__()
|
||||
|
||||
def forward(self, input1, input2):
|
||||
input1 = input1.contiguous()
|
||||
input2 = input2.contiguous()
|
||||
return chamfer_3DFunction_noGrad.apply(input1, input2)
|
14
third_party/ChamferDistancePytorch/chamfer3D/setup.py
vendored
Executable file
14
third_party/ChamferDistancePytorch/chamfer3D/setup.py
vendored
Executable file
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_3D',
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_3D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
|
||||
]),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
223
third_party/ChamferDistancePytorch/chamfer5D/chamfer5D.cu
vendored
Executable file
223
third_party/ChamferDistancePytorch/chamfer5D/chamfer5D.cu
vendored
Executable file
|
@ -0,0 +1,223 @@
|
|||
|
||||
#include <stdio.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
||||
|
||||
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
|
||||
const int batch=2048;
|
||||
__shared__ float buf[batch*5];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int k2=0;k2<m;k2+=batch){
|
||||
int end_k=min(m,k2+batch)-k2;
|
||||
for (int j=threadIdx.x;j<end_k*5;j+=blockDim.x){
|
||||
buf[j]=xyz2[(i*m+k2)*5+j];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz[(i*n+j)*5+0];
|
||||
float y1=xyz[(i*n+j)*5+1];
|
||||
float r1=xyz[(i*n+j)*5+2];
|
||||
float g1=xyz[(i*n+j)*5+3];
|
||||
float b1=xyz[(i*n+j)*5+4];
|
||||
int best_i=0;
|
||||
float best=0;
|
||||
int end_ka=end_k-(end_k&5);
|
||||
if (end_ka==batch){
|
||||
for (int k=0;k<batch;k+=4){
|
||||
{
|
||||
float x2=buf[k*5+0]-x1;
|
||||
float y2=buf[k*5+1]-y1;
|
||||
float r2=buf[k*5+2]-r1;
|
||||
float g2=buf[k*5+3]-g1;
|
||||
float b2=buf[k*5+4]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*5+5]-x1;
|
||||
float y2=buf[k*5+6]-y1;
|
||||
float r2=buf[k*5+7]-r1;
|
||||
float g2=buf[k*5+8]-g1;
|
||||
float b2=buf[k*5+9]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*5+10]-x1;
|
||||
float y2=buf[k*5+11]-y1;
|
||||
float r2=buf[k*5+12]-r1;
|
||||
float g2=buf[k*5+13]-g1;
|
||||
float b2=buf[k*5+14]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*5+15]-x1;
|
||||
float y2=buf[k*5+16]-y1;
|
||||
float r2=buf[k*5+17]-r1;
|
||||
float g2=buf[k*5+18]-g1;
|
||||
float b2=buf[k*5+19]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
for (int k=0;k<end_ka;k+=4){
|
||||
{
|
||||
float x2=buf[k*5+0]-x1;
|
||||
float y2=buf[k*5+1]-y1;
|
||||
float r2=buf[k*5+2]-r1;
|
||||
float g2=buf[k*5+3]-g1;
|
||||
float b2=buf[k*5+4]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*5+5]-x1;
|
||||
float y2=buf[k*5+6]-y1;
|
||||
float r2=buf[k*5+7]-r1;
|
||||
float g2=buf[k*5+8]-g1;
|
||||
float b2=buf[k*5+9]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*5+10]-x1;
|
||||
float y2=buf[k*5+11]-y1;
|
||||
float r2=buf[k*5+12]-r1;
|
||||
float g2=buf[k*5+13]-g1;
|
||||
float b2=buf[k*5+14]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*5+15]-x1;
|
||||
float y2=buf[k*5+16]-y1;
|
||||
float r2=buf[k*5+17]-r1;
|
||||
float g2=buf[k*5+18]-g1;
|
||||
float b2=buf[k*5+19]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int k=end_ka;k<end_k;k++){
|
||||
float x2=buf[k*5+0]-x1;
|
||||
float y2=buf[k*5+1]-y1;
|
||||
float r2=buf[k*5+2]-r1;
|
||||
float g2=buf[k*5+3]-g1;
|
||||
float b2=buf[k*5+4]-b1;
|
||||
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
if (k2==0 || result[(i*n+j)]>best){
|
||||
result[(i*n+j)]=best;
|
||||
result_i[(i*n+j)]=best_i;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
|
||||
}
|
||||
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz1[(i*n+j)*5+0];
|
||||
float y1=xyz1[(i*n+j)*5+1];
|
||||
float r1=xyz1[(i*n+j)*5+2];
|
||||
float g1=xyz1[(i*n+j)*5+3];
|
||||
float b1=xyz1[(i*n+j)*5+4];
|
||||
int j2=idx1[i*n+j];
|
||||
float x2=xyz2[(i*m+j2)*5+0];
|
||||
float y2=xyz2[(i*m+j2)*5+1];
|
||||
float r2=xyz2[(i*m+j2)*5+2];
|
||||
float g2=xyz2[(i*m+j2)*5+3];
|
||||
float b2=xyz2[(i*m+j2)*5+4];
|
||||
float g=grad_dist1[i*n+j]*2;
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*5+0]),g*(x1-x2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*5+1]),g*(y1-y2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*5+2]),g*(r1-r2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*5+3]),g*(g1-g2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*5+4]),g*(b1-b2));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*5+0]),-(g*(x1-x2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*5+1]),-(g*(y1-y2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*5+2]),-(g*(r1-r2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*5+3]),-(g*(g1-g2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*5+4]),-(g*(b1-b2)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
|
||||
// cudaMemset(grad_xyz1,0,b*n*3*4);
|
||||
// cudaMemset(grad_xyz2,0,b*m*3*4);
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
}
|
33
third_party/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp
vendored
Executable file
33
third_party/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp
vendored
Executable file
|
@ -0,0 +1,33 @@
|
|||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
///TMP
|
||||
//#include "common.h"
|
||||
/// NOT TMP
|
||||
|
||||
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
|
||||
|
||||
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
|
||||
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
|
||||
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
|
||||
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
|
||||
}
|
82
third_party/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py
vendored
Normal file
82
third_party/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py
vendored
Normal file
|
@ -0,0 +1,82 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
|
||||
chamfer_found = importlib.find_loader("chamfer_5D") is not None
|
||||
if not chamfer_found:
|
||||
## Cool trick from https://github.com/chrdiller
|
||||
print("Jitting Chamfer 5D")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
build_path = cur_path.replace('chamfer5D', 'tmp')
|
||||
os.makedirs(build_path, exist_ok=True)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_5D = load(name="chamfer_5D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
|
||||
], build_directory=build_path)
|
||||
print("Loaded JIT 5D CUDA chamfer distance")
|
||||
|
||||
else:
|
||||
import chamfer_5D
|
||||
print("Loaded compiled 5D CUDA chamfer distance")
|
||||
|
||||
|
||||
# Chamfer's distance module @thibaultgroueix
|
||||
# GPU tensors only
|
||||
class chamfer_5DFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
batchsize, n, dim = xyz1.size()
|
||||
assert dim==5, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
_, m, dim = xyz2.size()
|
||||
assert dim==5, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
device = xyz1.device
|
||||
|
||||
device = xyz1.device
|
||||
|
||||
dist1 = torch.zeros(batchsize, n)
|
||||
dist2 = torch.zeros(batchsize, m)
|
||||
|
||||
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
|
||||
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
|
||||
|
||||
dist1 = dist1.to(device)
|
||||
dist2 = dist2.to(device)
|
||||
idx1 = idx1.to(device)
|
||||
idx2 = idx2.to(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
|
||||
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
|
||||
return dist1, dist2, idx1, idx2
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
|
||||
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
|
||||
graddist1 = graddist1.contiguous()
|
||||
graddist2 = graddist2.contiguous()
|
||||
device = graddist1.device
|
||||
|
||||
gradxyz1 = torch.zeros(xyz1.size())
|
||||
gradxyz2 = torch.zeros(xyz2.size())
|
||||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_5D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
class chamfer_5DDist(nn.Module):
|
||||
def __init__(self):
|
||||
super(chamfer_5DDist, self).__init__()
|
||||
|
||||
def forward(self, input1, input2):
|
||||
input1 = input1.contiguous()
|
||||
input2 = input2.contiguous()
|
||||
return chamfer_5DFunction.apply(input1, input2)
|
14
third_party/ChamferDistancePytorch/chamfer5D/setup.py
vendored
Executable file
14
third_party/ChamferDistancePytorch/chamfer5D/setup.py
vendored
Executable file
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_5D',
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_5D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
|
||||
]),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
237
third_party/ChamferDistancePytorch/chamfer6D/chamfer6D.cu
vendored
Executable file
237
third_party/ChamferDistancePytorch/chamfer6D/chamfer6D.cu
vendored
Executable file
|
@ -0,0 +1,237 @@
|
|||
|
||||
#include <stdio.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
||||
|
||||
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
|
||||
const int batch=2048;
|
||||
__shared__ float buf[batch*6];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int k2=0;k2<m;k2+=batch){
|
||||
int end_k=min(m,k2+batch)-k2;
|
||||
for (int j=threadIdx.x;j<end_k*6;j+=blockDim.x){
|
||||
buf[j]=xyz2[(i*m+k2)*6+j];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz[(i*n+j)*6+0];
|
||||
float y1=xyz[(i*n+j)*6+1];
|
||||
float z1=xyz[(i*n+j)*6+2];
|
||||
float nx1=xyz[(i*n+j)*6+3];
|
||||
float ny1=xyz[(i*n+j)*6+4];
|
||||
float nz1=xyz[(i*n+j)*6+5];
|
||||
int best_i=0;
|
||||
float best=0;
|
||||
int end_ka=end_k-(end_k&6);
|
||||
if (end_ka==batch){
|
||||
for (int k=0;k<batch;k+=4){
|
||||
{
|
||||
float x2=buf[k*6+0]-x1;
|
||||
float y2=buf[k*6+1]-y1;
|
||||
float z2=buf[k*6+2]-z1;
|
||||
float nx2=buf[k*6+3]-nx1;
|
||||
float ny2=buf[k*6+4]-ny1;
|
||||
float nz2=buf[k*6+5]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*6+6]-x1;
|
||||
float y2=buf[k*6+7]-y1;
|
||||
float z2=buf[k*6+8]-z1;
|
||||
float nx2=buf[k*6+9]-nx1;
|
||||
float ny2=buf[k*6+10]-ny1;
|
||||
float nz2=buf[k*6+11]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*6+12]-x1;
|
||||
float y2=buf[k*6+13]-y1;
|
||||
float z2=buf[k*6+14]-z1;
|
||||
float nx2=buf[k*6+15]-nx1;
|
||||
float ny2=buf[k*6+16]-ny1;
|
||||
float nz2=buf[k*6+17]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*6+18]-x1;
|
||||
float y2=buf[k*6+19]-y1;
|
||||
float z2=buf[k*6+20]-z1;
|
||||
float nx2=buf[k*6+21]-nx1;
|
||||
float ny2=buf[k*6+22]-ny1;
|
||||
float nz2=buf[k*6+23]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
for (int k=0;k<end_ka;k+=4){
|
||||
{
|
||||
float x2=buf[k*6+0]-x1;
|
||||
float y2=buf[k*6+1]-y1;
|
||||
float z2=buf[k*6+2]-z1;
|
||||
float nx2=buf[k*6+3]-nx1;
|
||||
float ny2=buf[k*6+4]-ny1;
|
||||
float nz2=buf[k*6+5]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*6+6]-x1;
|
||||
float y2=buf[k*6+7]-y1;
|
||||
float z2=buf[k*6+8]-z1;
|
||||
float nx2=buf[k*6+9]-nx1;
|
||||
float ny2=buf[k*6+10]-ny1;
|
||||
float nz2=buf[k*6+11]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+1;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*6+12]-x1;
|
||||
float y2=buf[k*6+13]-y1;
|
||||
float z2=buf[k*6+14]-z1;
|
||||
float nx2=buf[k*6+15]-nx1;
|
||||
float ny2=buf[k*6+16]-ny1;
|
||||
float nz2=buf[k*6+17]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+2;
|
||||
}
|
||||
}
|
||||
{
|
||||
float x2=buf[k*6+18]-x1;
|
||||
float y2=buf[k*6+19]-y1;
|
||||
float z2=buf[k*6+20]-z1;
|
||||
float nx2=buf[k*6+21]-nx1;
|
||||
float ny2=buf[k*6+22]-ny1;
|
||||
float nz2=buf[k*6+23]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (d<best){
|
||||
best=d;
|
||||
best_i=k+k2+3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int k=end_ka;k<end_k;k++){
|
||||
float x2=buf[k*6+0]-x1;
|
||||
float y2=buf[k*6+1]-y1;
|
||||
float z2=buf[k*6+2]-z1;
|
||||
float nx2=buf[k*6+3]-nx1;
|
||||
float ny2=buf[k*6+4]-ny1;
|
||||
float nz2=buf[k*6+5]-nz1;
|
||||
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
|
||||
if (k==0 || d<best){
|
||||
best=d;
|
||||
best_i=k+k2;
|
||||
}
|
||||
}
|
||||
if (k2==0 || result[(i*n+j)]>best){
|
||||
result[(i*n+j)]=best;
|
||||
result_i[(i*n+j)]=best_i;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
|
||||
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
|
||||
}
|
||||
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
|
||||
float x1=xyz1[(i*n+j)*6+0];
|
||||
float y1=xyz1[(i*n+j)*6+1];
|
||||
float z1=xyz1[(i*n+j)*6+2];
|
||||
float nx1=xyz1[(i*n+j)*6+3];
|
||||
float ny1=xyz1[(i*n+j)*6+4];
|
||||
float nz1=xyz1[(i*n+j)*6+5];
|
||||
int j2=idx1[i*n+j];
|
||||
float x2=xyz2[(i*m+j2)*6+0];
|
||||
float y2=xyz2[(i*m+j2)*6+1];
|
||||
float z2=xyz2[(i*m+j2)*6+2];
|
||||
float nx2=xyz2[(i*m+j2)*6+3];
|
||||
float ny2=xyz2[(i*m+j2)*6+4];
|
||||
float nz2=xyz2[(i*m+j2)*6+5];
|
||||
float g=grad_dist1[i*n+j]*2;
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*6+0]),g*(x1-x2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*6+1]),g*(y1-y2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*6+2]),g*(z1-z2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*6+3]),g*(nx1-nx2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*6+4]),g*(ny1-ny2));
|
||||
atomicAdd(&(grad_xyz1[(i*n+j)*6+5]),g*(nz1-nz2));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*6+0]),-(g*(x1-x2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*6+1]),-(g*(y1-y2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*6+2]),-(g*(z1-z2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*6+3]),-(g*(nx1-nx2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*6+4]),-(g*(ny1-ny2)));
|
||||
atomicAdd(&(grad_xyz2[(i*m+j2)*6+5]),-(g*(nz1-nz2)));
|
||||
}
|
||||
}
|
||||
}
|
||||
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
|
||||
// cudaMemset(grad_xyz1,0,b*n*3*4);
|
||||
// cudaMemset(grad_xyz2,0,b*m*3*4);
|
||||
|
||||
const auto batch_size = xyz1.size(0);
|
||||
const auto n = xyz1.size(1); //num_points point cloud A
|
||||
const auto m = xyz2.size(1); //num_points point cloud B
|
||||
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
|
||||
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
|
||||
//THError("aborting");
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
|
||||
}
|
33
third_party/ChamferDistancePytorch/chamfer6D/chamfer_cuda.cpp
vendored
Executable file
33
third_party/ChamferDistancePytorch/chamfer6D/chamfer_cuda.cpp
vendored
Executable file
|
@ -0,0 +1,33 @@
|
|||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
///TMP
|
||||
//#include "common.h"
|
||||
/// NOT TMP
|
||||
|
||||
|
||||
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
|
||||
|
||||
|
||||
|
||||
|
||||
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
|
||||
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
|
||||
|
||||
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
|
||||
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
|
||||
}
|
82
third_party/ChamferDistancePytorch/chamfer6D/dist_chamfer_6D.py
vendored
Executable file
82
third_party/ChamferDistancePytorch/chamfer6D/dist_chamfer_6D.py
vendored
Executable file
|
@ -0,0 +1,82 @@
|
|||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import torch
|
||||
import importlib
|
||||
import os
|
||||
|
||||
chamfer_found = importlib.find_loader("chamfer_6D") is not None
|
||||
if not chamfer_found:
|
||||
## Cool trick from https://github.com/chrdiller
|
||||
print("Jitting Chamfer 6D")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
build_path = cur_path.replace('chamfer6D', 'tmp')
|
||||
os.makedirs(build_path, exist_ok=True)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
chamfer_6D = load(name="chamfer_6D",
|
||||
sources=[
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
|
||||
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer6D.cu"]),
|
||||
], build_directory=build_path)
|
||||
print("Loaded JIT 6D CUDA chamfer distance")
|
||||
|
||||
else:
|
||||
import chamfer_6D
|
||||
print("Loaded compiled 6D CUDA chamfer distance")
|
||||
|
||||
|
||||
# Chamfer's distance module @thibaultgroueix
|
||||
# GPU tensors only
|
||||
class chamfer_6DFunction(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
batchsize, n, dim = xyz1.size()
|
||||
assert dim==6, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
_, m, dim = xyz2.size()
|
||||
assert dim==6, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
|
||||
device = xyz1.device
|
||||
|
||||
device = xyz1.device
|
||||
|
||||
dist1 = torch.zeros(batchsize, n)
|
||||
dist2 = torch.zeros(batchsize, m)
|
||||
|
||||
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
|
||||
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
|
||||
|
||||
dist1 = dist1.to(device)
|
||||
dist2 = dist2.to(device)
|
||||
idx1 = idx1.to(device)
|
||||
idx2 = idx2.to(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
chamfer_6D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
|
||||
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
|
||||
return dist1, dist2, idx1, idx2
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
|
||||
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
|
||||
graddist1 = graddist1.contiguous()
|
||||
graddist2 = graddist2.contiguous()
|
||||
device = graddist1.device
|
||||
|
||||
gradxyz1 = torch.zeros(xyz1.size())
|
||||
gradxyz2 = torch.zeros(xyz2.size())
|
||||
|
||||
gradxyz1 = gradxyz1.to(device)
|
||||
gradxyz2 = gradxyz2.to(device)
|
||||
chamfer_6D.backward(
|
||||
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
|
||||
)
|
||||
return gradxyz1, gradxyz2
|
||||
|
||||
|
||||
class chamfer_6DDist(nn.Module):
|
||||
def __init__(self):
|
||||
super(chamfer_6DDist, self).__init__()
|
||||
|
||||
def forward(self, input1, input2):
|
||||
input1 = input1.contiguous()
|
||||
input2 = input2.contiguous()
|
||||
return chamfer_6DFunction.apply(input1, input2)
|
14
third_party/ChamferDistancePytorch/chamfer6D/setup.py
vendored
Executable file
14
third_party/ChamferDistancePytorch/chamfer6D/setup.py
vendored
Executable file
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='chamfer_6D',
|
||||
ext_modules=[
|
||||
CUDAExtension('chamfer_6D', [
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
|
||||
"/".join(__file__.split('/')[:-1] + ['chamfer6D.cu']),
|
||||
]),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
44
third_party/ChamferDistancePytorch/chamfer_python.py
vendored
Normal file
44
third_party/ChamferDistancePytorch/chamfer_python.py
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
|
||||
|
||||
def pairwise_dist(x, y):
|
||||
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
|
||||
rx = xx.diag().unsqueeze(0).expand_as(xx)
|
||||
ry = yy.diag().unsqueeze(0).expand_as(yy)
|
||||
P = rx.t() + ry - 2 * zz
|
||||
return P
|
||||
|
||||
|
||||
def NN_loss(x, y, dim=0):
|
||||
dist = pairwise_dist(x, y)
|
||||
values, indices = dist.min(dim=dim)
|
||||
return values.mean()
|
||||
|
||||
|
||||
def batched_pairwise_dist(a, b):
|
||||
x, y = a.double(), b.double()
|
||||
bs, num_points_x, points_dim = x.size()
|
||||
bs, num_points_y, points_dim = y.size()
|
||||
|
||||
xx = torch.pow(x, 2).sum(2)
|
||||
yy = torch.pow(y, 2).sum(2)
|
||||
zz = torch.bmm(x, y.transpose(2, 1))
|
||||
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
|
||||
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
|
||||
P = rx.transpose(2, 1) + ry - 2 * zz
|
||||
return P
|
||||
|
||||
def distChamfer(a, b):
|
||||
"""
|
||||
:param a: Pointclouds Batch x nul_points x dim
|
||||
:param b: Pointclouds Batch x nul_points x dim
|
||||
:return:
|
||||
-closest point on b of points from a
|
||||
-closest point on a of points from b
|
||||
-idx of closest point on b of points from a
|
||||
-idx of closest point on a of points from b
|
||||
Works for pointcloud of any dimension
|
||||
"""
|
||||
P = batched_pairwise_dist(a, b)
|
||||
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
|
||||
|
17
third_party/ChamferDistancePytorch/fscore.py
vendored
Normal file
17
third_party/ChamferDistancePytorch/fscore.py
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
import torch
|
||||
|
||||
def fscore(dist1, dist2, threshold=0.001):
|
||||
"""
|
||||
Calculates the F-score between two point clouds with the corresponding threshold value.
|
||||
:param dist1: Batch, N-Points
|
||||
:param dist2: Batch, N-Points
|
||||
:param th: float
|
||||
:return: fscore, precision, recall
|
||||
"""
|
||||
# NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly.
|
||||
precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
|
||||
precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
|
||||
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
|
||||
fscore[torch.isnan(fscore)] = 0
|
||||
return fscore, precision_1, precision_2
|
||||
|
69
third_party/ChamferDistancePytorch/unit_test.py
vendored
Normal file
69
third_party/ChamferDistancePytorch/unit_test.py
vendored
Normal file
|
@ -0,0 +1,69 @@
|
|||
import torch, time
|
||||
import chamfer2D.dist_chamfer_2D
|
||||
import chamfer3D.dist_chamfer_3D
|
||||
import chamfer5D.dist_chamfer_5D
|
||||
import chamfer_python
|
||||
|
||||
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
|
||||
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
|
||||
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
|
||||
|
||||
from torch.autograd import Variable
|
||||
from fscore import fscore
|
||||
|
||||
def test_chamfer(distChamfer, dim):
|
||||
points1 = torch.rand(4, 100, dim).cuda()
|
||||
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
|
||||
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
|
||||
|
||||
loss = torch.sum(dist1)
|
||||
loss.backward()
|
||||
|
||||
mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2)
|
||||
d1 = (dist1 - mydist1) ** 2
|
||||
d2 = (dist2 - mydist2) ** 2
|
||||
assert (
|
||||
torch.mean(d1) + torch.mean(d2) < 0.00000001
|
||||
), "chamfer cuda and chamfer normal are not giving the same results"
|
||||
|
||||
xd1 = idx1 - myidx1
|
||||
xd2 = idx2 - myidx2
|
||||
assert (
|
||||
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
|
||||
), "chamfer cuda and chamfer normal are not giving the same results"
|
||||
print(f"fscore :", fscore(dist1, dist2))
|
||||
print("Unit test passed")
|
||||
|
||||
|
||||
def timings(distChamfer, dim):
|
||||
p1 = torch.rand(32, 2000, dim).cuda()
|
||||
p2 = torch.rand(32, 1000, dim).cuda()
|
||||
print("Timings : Start CUDA version")
|
||||
start = time.time()
|
||||
num_it = 100
|
||||
for i in range(num_it):
|
||||
points1 = Variable(p1, requires_grad=True)
|
||||
points2 = Variable(p2)
|
||||
mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2)
|
||||
loss = torch.sum(mydist1)
|
||||
loss.backward()
|
||||
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
||||
|
||||
|
||||
print("Timings : Start Pythonic version")
|
||||
start = time.time()
|
||||
for i in range(num_it):
|
||||
points1 = Variable(p1, requires_grad=True)
|
||||
points2 = Variable(p2)
|
||||
mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2)
|
||||
loss = torch.sum(mydist1)
|
||||
loss.backward()
|
||||
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
|
||||
|
||||
|
||||
|
||||
dims = [2,3,5]
|
||||
for i,cham in enumerate([cham2D, cham3D, cham5D]):
|
||||
print(f"testing Chamfer {dims[i]}D")
|
||||
test_chamfer(cham, dims[i])
|
||||
timings(cham, dims[i])
|
5
third_party/PyTorchEMD/.gitignore
vendored
Normal file
5
third_party/PyTorchEMD/.gitignore
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
__pycache__
|
||||
build
|
||||
dist
|
||||
emd_ext.egg-info
|
||||
*.so
|
34
third_party/PyTorchEMD/README.md
vendored
Normal file
34
third_party/PyTorchEMD/README.md
vendored
Normal file
|
@ -0,0 +1,34 @@
|
|||
* adapted from https://github.com/daerduoCarey/PyTorchEMD
|
||||
|
||||
---------------------------------
|
||||
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
|
||||
|
||||
## Dependency
|
||||
|
||||
The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.
|
||||
|
||||
## Usage
|
||||
|
||||
First compile using
|
||||
|
||||
python setup.py install
|
||||
|
||||
Then, copy the lib file out to the main directory,
|
||||
|
||||
cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
|
||||
|
||||
Then, you can use it by simply
|
||||
|
||||
from emd import earth_mover_distance
|
||||
d = earth_mover_distance(p1, p2, transpose=False) # p1: B x N1 x 3, p2: B x N2 x 3
|
||||
|
||||
Check `test_emd_loss.py` for example.
|
||||
|
||||
## Author
|
||||
|
||||
The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
0
third_party/PyTorchEMD/__init__.py
vendored
Executable file
0
third_party/PyTorchEMD/__init__.py
vendored
Executable file
21
third_party/PyTorchEMD/backend.py
vendored
Executable file
21
third_party/PyTorchEMD/backend.py
vendored
Executable file
|
@ -0,0 +1,21 @@
|
|||
import os
|
||||
import time
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if not os.path.exists(os.path.join(_src_path, 'build_dynamic')):
|
||||
os.makedirs(os.path.join(_src_path, 'build_dynamic'))
|
||||
tic = time.time()
|
||||
emd_cuda_dynamic = load(name='emd_ext',
|
||||
extra_cflags=['-O3', '-std=c++17'],
|
||||
## build_directory=os.path.join(_src_path, 'build_dynamic'),
|
||||
verbose=True,
|
||||
sources=[
|
||||
os.path.join(_src_path, f) for f in [
|
||||
'cuda/emd.cpp',
|
||||
'cuda/emd_kernel.cu',
|
||||
]
|
||||
])
|
||||
print('load emd_ext time: {:.3f}s'.format(time.time() - tic))
|
||||
__all__ = ['emd_cuda_dynamic']
|
29
third_party/PyTorchEMD/cuda/emd.cpp
vendored
Executable file
29
third_party/PyTorchEMD/cuda/emd.cpp
vendored
Executable file
|
@ -0,0 +1,29 @@
|
|||
#ifndef _EMD
|
||||
#define _EMD
|
||||
|
||||
#include <vector>
|
||||
#include <torch/extension.h>
|
||||
|
||||
//CUDA declarations
|
||||
at::Tensor ApproxMatchForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2);
|
||||
|
||||
at::Tensor MatchCostForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match);
|
||||
|
||||
std::vector<at::Tensor> MatchCostBackward(
|
||||
const at::Tensor grad_cost,
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
|
||||
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
|
||||
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
|
||||
}
|
||||
|
||||
#endif
|
398
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
Normal file
398
third_party/PyTorchEMD/cuda/emd_kernel.cu
vendored
Normal file
|
@ -0,0 +1,398 @@
|
|||
/**********************************
|
||||
* Original Author: Haoqiang Fan
|
||||
* Modified by: Kaichun Mo
|
||||
*********************************/
|
||||
|
||||
#ifndef _EMD_KERNEL
|
||||
#define _EMD_KERNEL
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
|
||||
#include <THC/THC.h>
|
||||
|
||||
#define CHECK_INPUT(x)
|
||||
|
||||
|
||||
/********************************
|
||||
* Forward kernel for approxmatch
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
|
||||
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
|
||||
scalar_t multiL,multiR;
|
||||
if (n>=m){
|
||||
multiL=1;
|
||||
multiR=n/m;
|
||||
}else{
|
||||
multiL=m/n;
|
||||
multiR=1;
|
||||
}
|
||||
const int Block=1024;
|
||||
__shared__ scalar_t buf[Block*4];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
|
||||
match[i*n*m+j]=0;
|
||||
for (int j=threadIdx.x;j<n;j+=blockDim.x)
|
||||
remainL[j]=multiL;
|
||||
for (int j=threadIdx.x;j<m;j+=blockDim.x)
|
||||
remainR[j]=multiR;
|
||||
__syncthreads();
|
||||
for (int j=7;j>=-2;j--){
|
||||
scalar_t level=-powf(4.0f,j);
|
||||
if (j==-2){
|
||||
level=0;
|
||||
}
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
int k=k0+threadIdx.x;
|
||||
scalar_t x1=0,y1=0,z1=0;
|
||||
if (k<n){
|
||||
x1=xyz1[i*n*3+k*3+0];
|
||||
y1=xyz1[i*n*3+k*3+1];
|
||||
z1=xyz1[i*n*3+k*3+2];
|
||||
}
|
||||
scalar_t suml=1e-9f;
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
int lend=min(m,l0+Block)-l0;
|
||||
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||
scalar_t x2=xyz2[i*m*3+l0*3+l*3+0];
|
||||
scalar_t y2=xyz2[i*m*3+l0*3+l*3+1];
|
||||
scalar_t z2=xyz2[i*m*3+l0*3+l*3+2];
|
||||
buf[l*4+0]=x2;
|
||||
buf[l*4+1]=y2;
|
||||
buf[l*4+2]=z2;
|
||||
buf[l*4+3]=remainR[l0+l];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l=0;l<lend;l++){
|
||||
scalar_t x2=buf[l*4+0];
|
||||
scalar_t y2=buf[l*4+1];
|
||||
scalar_t z2=buf[l*4+2];
|
||||
scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
|
||||
scalar_t w=__expf(d)*buf[l*4+3];
|
||||
suml+=w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k<n)
|
||||
ratioL[k]=remainL[k]/suml;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int l0=0;l0<m;l0+=blockDim.x){
|
||||
int l=l0+threadIdx.x;
|
||||
scalar_t x2=0,y2=0,z2=0;
|
||||
if (l<m){
|
||||
x2=xyz2[i*m*3+l*3+0];
|
||||
y2=xyz2[i*m*3+l*3+1];
|
||||
z2=xyz2[i*m*3+l*3+2];
|
||||
}
|
||||
scalar_t sumr=0;
|
||||
for (int k0=0;k0<n;k0+=Block){
|
||||
int kend=min(n,k0+Block)-k0;
|
||||
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
|
||||
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
|
||||
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
|
||||
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
|
||||
buf[k*4+3]=ratioL[k0+k];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k=0;k<kend;k++){
|
||||
scalar_t x1=buf[k*4+0];
|
||||
scalar_t y1=buf[k*4+1];
|
||||
scalar_t z1=buf[k*4+2];
|
||||
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
|
||||
sumr+=w;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (l<m){
|
||||
sumr*=remainR[l];
|
||||
scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
|
||||
ratioR[l]=consumption*remainR[l];
|
||||
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
int k=k0+threadIdx.x;
|
||||
scalar_t x1=0,y1=0,z1=0;
|
||||
if (k<n){
|
||||
x1=xyz1[i*n*3+k*3+0];
|
||||
y1=xyz1[i*n*3+k*3+1];
|
||||
z1=xyz1[i*n*3+k*3+2];
|
||||
}
|
||||
scalar_t suml=0;
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
int lend=min(m,l0+Block)-l0;
|
||||
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
|
||||
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
|
||||
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
|
||||
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
|
||||
buf[l*4+3]=ratioR[l0+l];
|
||||
}
|
||||
__syncthreads();
|
||||
scalar_t rl=ratioL[k];
|
||||
if (k<n){
|
||||
for (int l=0;l<lend;l++){
|
||||
scalar_t x2=buf[l*4+0];
|
||||
scalar_t y2=buf[l*4+1];
|
||||
scalar_t z2=buf[l*4+2];
|
||||
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
|
||||
match[i*n*m+(l0+l)*n+k]+=w;
|
||||
suml+=w;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (k<n)
|
||||
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
|
||||
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
|
||||
//}
|
||||
|
||||
/* ApproxMatch forward interface
|
||||
Input:
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
xyz2: (B, N2, 3) # query_points
|
||||
Output:
|
||||
match: (B, N2, N1)
|
||||
*/
|
||||
at::Tensor ApproxMatchForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2){
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
auto match = at::zeros({b, m, n}, xyz1.type());
|
||||
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
|
||||
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
|
||||
}));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return match;
|
||||
}
|
||||
|
||||
|
||||
/********************************
|
||||
* Forward kernel for matchcost
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
|
||||
__shared__ scalar_t allsum[512];
|
||||
const int Block=1024;
|
||||
__shared__ scalar_t buf[Block*3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
scalar_t subsum=0;
|
||||
for (int k0=0;k0<n;k0+=blockDim.x){
|
||||
int k=k0+threadIdx.x;
|
||||
scalar_t x1=0,y1=0,z1=0;
|
||||
if (k<n){
|
||||
x1=xyz1[i*n*3+k*3+0];
|
||||
y1=xyz1[i*n*3+k*3+1];
|
||||
z1=xyz1[i*n*3+k*3+2];
|
||||
}
|
||||
for (int l0=0;l0<m;l0+=Block){
|
||||
int lend=min(m,l0+Block)-l0;
|
||||
for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
|
||||
buf[l]=xyz2[i*m*3+l0*3+l];
|
||||
__syncthreads();
|
||||
if (k<n){
|
||||
for (int l=0;l<lend;l++){
|
||||
scalar_t x2=buf[l*3+0];
|
||||
scalar_t y2=buf[l*3+1];
|
||||
scalar_t z2=buf[l*3+2];
|
||||
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
|
||||
subsum+=d*match[i*n*m+(l0+l)*n+k];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
allsum[threadIdx.x]=subsum;
|
||||
for (int j=1;j<blockDim.x;j<<=1){
|
||||
__syncthreads();
|
||||
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
|
||||
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
|
||||
}
|
||||
}
|
||||
if (threadIdx.x==0)
|
||||
out[i]=allsum[0];
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
|
||||
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
|
||||
//}
|
||||
|
||||
/* MatchCost forward interface
|
||||
Input:
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
xyz2: (B, N2, 3) # query_points
|
||||
match: (B, N2, N1)
|
||||
Output:
|
||||
cost: (B)
|
||||
*/
|
||||
at::Tensor MatchCostForward(
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match){
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
auto cost = at::zeros({b}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
|
||||
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
|
||||
}));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
/********************************
|
||||
* matchcostgrad2 kernel
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
|
||||
__shared__ scalar_t sum_grad[256*3];
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
int kbeg=m*blockIdx.y/gridDim.y;
|
||||
int kend=m*(blockIdx.y+1)/gridDim.y;
|
||||
for (int k=kbeg;k<kend;k++){
|
||||
scalar_t x2=xyz2[(i*m+k)*3+0];
|
||||
scalar_t y2=xyz2[(i*m+k)*3+1];
|
||||
scalar_t z2=xyz2[(i*m+k)*3+2];
|
||||
scalar_t subsumx=0,subsumy=0,subsumz=0;
|
||||
for (int j=threadIdx.x;j<n;j+=blockDim.x){
|
||||
scalar_t x1=x2-xyz1[(i*n+j)*3+0];
|
||||
scalar_t y1=y2-xyz1[(i*n+j)*3+1];
|
||||
scalar_t z1=z2-xyz1[(i*n+j)*3+2];
|
||||
scalar_t d=match[i*n*m+k*n+j]*2;
|
||||
subsumx+=x1*d;
|
||||
subsumy+=y1*d;
|
||||
subsumz+=z1*d;
|
||||
}
|
||||
sum_grad[threadIdx.x*3+0]=subsumx;
|
||||
sum_grad[threadIdx.x*3+1]=subsumy;
|
||||
sum_grad[threadIdx.x*3+2]=subsumz;
|
||||
for (int j=1;j<blockDim.x;j<<=1){
|
||||
__syncthreads();
|
||||
int j1=threadIdx.x;
|
||||
int j2=threadIdx.x+j;
|
||||
if ((j1&j)==0 && j2<blockDim.x){
|
||||
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
|
||||
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
|
||||
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
|
||||
}
|
||||
}
|
||||
if (threadIdx.x==0){
|
||||
grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
|
||||
grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
|
||||
grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/********************************
|
||||
* matchcostgrad1 kernel
|
||||
*********************************/
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
|
||||
for (int i=blockIdx.x;i<b;i+=gridDim.x){
|
||||
for (int l=threadIdx.x;l<n;l+=blockDim.x){
|
||||
scalar_t x1=xyz1[i*n*3+l*3+0];
|
||||
scalar_t y1=xyz1[i*n*3+l*3+1];
|
||||
scalar_t z1=xyz1[i*n*3+l*3+2];
|
||||
scalar_t dx=0,dy=0,dz=0;
|
||||
for (int k=0;k<m;k++){
|
||||
scalar_t x2=xyz2[i*m*3+k*3+0];
|
||||
scalar_t y2=xyz2[i*m*3+k*3+1];
|
||||
scalar_t z2=xyz2[i*m*3+k*3+2];
|
||||
scalar_t d=match[i*n*m+k*n+l]*2;
|
||||
dx+=(x1-x2)*d;
|
||||
dy+=(y1-y2)*d;
|
||||
dz+=(z1-z2)*d;
|
||||
}
|
||||
grad1[i*n*3+l*3+0]=dx*grad_cost[i];
|
||||
grad1[i*n*3+l*3+1]=dy*grad_cost[i];
|
||||
grad1[i*n*3+l*3+2]=dz*grad_cost[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
|
||||
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
|
||||
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
|
||||
//}
|
||||
|
||||
|
||||
/* MatchCost backward interface
|
||||
Input:
|
||||
grad_cost: (B) # gradients on cost
|
||||
xyz1: (B, N1, 3) # dataset_points
|
||||
xyz2: (B, N2, 3) # query_points
|
||||
match: (B, N2, N1)
|
||||
Output:
|
||||
grad1: (B, N1, 3)
|
||||
grad2: (B, N2, 3)
|
||||
*/
|
||||
std::vector<at::Tensor> MatchCostBackward(
|
||||
const at::Tensor grad_cost,
|
||||
const at::Tensor xyz1,
|
||||
const at::Tensor xyz2,
|
||||
const at::Tensor match){
|
||||
const auto b = xyz1.size(0);
|
||||
const auto n = xyz1.size(1);
|
||||
const auto m = xyz2.size(1);
|
||||
|
||||
CHECK_EQ(xyz2.size(0), b);
|
||||
CHECK_EQ(xyz1.size(2), 3);
|
||||
CHECK_EQ(xyz2.size(2), 3);
|
||||
CHECK_INPUT(xyz1);
|
||||
CHECK_INPUT(xyz2);
|
||||
|
||||
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
|
||||
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
|
||||
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
|
||||
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
|
||||
}));
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return std::vector<at::Tensor>({grad1, grad2});
|
||||
}
|
||||
|
||||
#endif
|
52
third_party/PyTorchEMD/emd.py
vendored
Executable file
52
third_party/PyTorchEMD/emd.py
vendored
Executable file
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
# from backend import emd_cuda_dynamic as emd_cuda # jit compiling
|
||||
from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda
|
||||
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||
ctx.save_for_backward(xyz1, xyz2, match)
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_cost):
|
||||
xyz1, xyz2, match = ctx.saved_tensors
|
||||
grad_cost = grad_cost.contiguous()
|
||||
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||
return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||
"""Earth Mover Distance (Approx)
|
||||
|
||||
Args:
|
||||
xyz1 (torch.Tensor): (b, 3, n1)
|
||||
xyz2 (torch.Tensor): (b, 3, n1)
|
||||
transpose (bool): whether to transpose inputs as it might be BCN format.
|
||||
Extensions only support BNC format.
|
||||
|
||||
Returns:
|
||||
cost (torch.Tensor): (b)
|
||||
|
||||
"""
|
||||
if xyz1.dim() == 2:
|
||||
xyz1 = xyz1.unsqueeze(0)
|
||||
if xyz2.dim() == 2:
|
||||
xyz2 = xyz2.unsqueeze(0)
|
||||
if transpose:
|
||||
xyz1 = xyz1.transpose(1, 2)
|
||||
xyz2 = xyz2.transpose(1, 2)
|
||||
# xyz1: B,N,3
|
||||
N = xyz1.shape[1]
|
||||
assert(xyz1.shape[-1] == 3), f'require it to be B,N,3; get: {xyz1.shape}'
|
||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(N)
|
||||
return cost
|
||||
|
9
third_party/PyTorchEMD/emd_cuda.py
vendored
Normal file
9
third_party/PyTorchEMD/emd_cuda.py
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
def __bootstrap__():
|
||||
global __bootstrap__, __loader__, __file__
|
||||
import sys, pkg_resources, importlib.util
|
||||
__file__ = pkg_resources.resource_filename(__name__, 'emd_cuda.cpython-38-x86_64-linux-gnu.so')
|
||||
__loader__ = None; del __bootstrap__, __loader__
|
||||
spec = importlib.util.spec_from_file_location(__name__,__file__)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
__bootstrap__()
|
45
third_party/PyTorchEMD/emd_nograd.py
vendored
Normal file
45
third_party/PyTorchEMD/emd_nograd.py
vendored
Normal file
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
#import emd_cuda
|
||||
# from evaluation.PyTorchEMD import emd_cuda
|
||||
from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda
|
||||
|
||||
|
||||
class EarthMoverDistanceFunctionNoGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||
# ctx.save_for_backward(xyz1, xyz2, match)
|
||||
return cost
|
||||
|
||||
|
||||
def earth_mover_distance_nograd(xyz1, xyz2, transpose=True):
|
||||
"""Earth Mover Distance (Approx)
|
||||
|
||||
Args:
|
||||
xyz1 (torch.Tensor): (b, 3, n1)
|
||||
xyz2 (torch.Tensor): (b, 3, n1)
|
||||
transpose (bool): whether to transpose inputs as it might be BCN format.
|
||||
Extensions only support BNC format.
|
||||
|
||||
Returns:
|
||||
cost (torch.Tensor): (b)
|
||||
|
||||
"""
|
||||
if xyz1.dim() == 2:
|
||||
xyz1 = xyz1.unsqueeze(0)
|
||||
if xyz2.dim() == 2:
|
||||
xyz2 = xyz2.unsqueeze(0)
|
||||
if transpose:
|
||||
xyz1 = xyz1.transpose(1, 2)
|
||||
xyz2 = xyz2.transpose(1, 2)
|
||||
# xyz1: B,N,3
|
||||
N = xyz1.shape[1]
|
||||
assert(xyz1.shape[-1] == 3), f'require it to be B,N,3; get: {xyz1.shape}'
|
||||
#print('xyz1: ', xyz1.shape, xyz2.shape, xyz1.min(), xyz1.max(), xyz2.min(), xyz2.max())
|
||||
cost = EarthMoverDistanceFunctionNoGrad.apply(xyz1, xyz2) / float(N)
|
||||
return cost
|
||||
|
49
third_party/PyTorchEMD/emd_static.py
vendored
Executable file
49
third_party/PyTorchEMD/emd_static.py
vendored
Executable file
|
@ -0,0 +1,49 @@
|
|||
import torch
|
||||
import emd_cuda
|
||||
|
||||
|
||||
class EarthMoverDistanceFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1, xyz2):
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
|
||||
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
|
||||
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
|
||||
ctx.save_for_backward(xyz1, xyz2, match)
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_cost):
|
||||
xyz1, xyz2, match = ctx.saved_tensors
|
||||
grad_cost = grad_cost.contiguous()
|
||||
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
|
||||
return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
def earth_mover_distance(xyz1, xyz2, transpose=True):
|
||||
"""Earth Mover Distance (Approx)
|
||||
|
||||
Args:
|
||||
xyz1 (torch.Tensor): (b, 3, n1)
|
||||
xyz2 (torch.Tensor): (b, 3, n1)
|
||||
transpose (bool): whether to transpose inputs as it might be BCN format.
|
||||
Extensions only support BNC format.
|
||||
|
||||
Returns:
|
||||
cost (torch.Tensor): (b)
|
||||
|
||||
"""
|
||||
if xyz1.dim() == 2:
|
||||
xyz1 = xyz1.unsqueeze(0)
|
||||
if xyz2.dim() == 2:
|
||||
xyz2 = xyz2.unsqueeze(0)
|
||||
if transpose:
|
||||
xyz1 = xyz1.transpose(1, 2)
|
||||
xyz2 = xyz2.transpose(1, 2)
|
||||
# xyz1: B,N,3
|
||||
N = xyz1.shape[1]
|
||||
assert(xyz1.shape[-1] == 3), f'require it to be B,N,3; get: {xyz1.shape}'
|
||||
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(N)
|
||||
return cost
|
||||
|
27
third_party/PyTorchEMD/setup.py
vendored
Executable file
27
third_party/PyTorchEMD/setup.py
vendored
Executable file
|
@ -0,0 +1,27 @@
|
|||
"""Setup extension
|
||||
|
||||
Notes:
|
||||
If extra_compile_args is provided, you need to provide different instances for different extensions.
|
||||
Refer to https://github.com/pytorch/pytorch/issues/20169
|
||||
|
||||
"""
|
||||
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
setup(
|
||||
name='emd_ext',
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='emd_cuda',
|
||||
sources=[
|
||||
'cuda/emd.cpp',
|
||||
'cuda/emd_kernel.cu',
|
||||
],
|
||||
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
})
|
44
third_party/PyTorchEMD/test_emd_loss.py
vendored
Normal file
44
third_party/PyTorchEMD/test_emd_loss.py
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from emd import earth_mover_distance
|
||||
|
||||
# gt
|
||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p1 = p1.repeat(3, 1, 1)
|
||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p2 = p2.repeat(3, 1, 1)
|
||||
print(p1)
|
||||
print(p2)
|
||||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
|
||||
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
|
||||
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
|
||||
print('gt_dist: ', gt_dist)
|
||||
|
||||
gt_dist.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
||||
# emd
|
||||
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p1 = p1.repeat(3, 1, 1)
|
||||
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
|
||||
p2 = p2.repeat(3, 1, 1)
|
||||
print(p1)
|
||||
print(p2)
|
||||
p1.requires_grad = True
|
||||
p2.requires_grad = True
|
||||
|
||||
d = earth_mover_distance(p1, p2, transpose=False)
|
||||
print(d)
|
||||
|
||||
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
|
||||
print(loss)
|
||||
|
||||
loss.backward()
|
||||
print(p1.grad)
|
||||
print(p2.grad)
|
||||
|
21
third_party/pvcnn/LICENSE
vendored
Normal file
21
third_party/pvcnn/LICENSE
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2018 Zhijian Liu, Haotian Tang, Yujun Lin
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
2
third_party/pvcnn/README.md
vendored
Normal file
2
third_party/pvcnn/README.md
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
* all the code under this folder is based on the code under https://github.com/mit-han-lab/pvcnn/tree/master/modules
|
||||
|
7
third_party/pvcnn/functional/__init__.py
vendored
Normal file
7
third_party/pvcnn/functional/__init__.py
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
from third_party.pvcnn.functional.ball_query import ball_query
|
||||
from third_party.pvcnn.functional.devoxelization import trilinear_devoxelize
|
||||
from third_party.pvcnn.functional.grouping import grouping
|
||||
from third_party.pvcnn.functional.interpolatation import nearest_neighbor_interpolate
|
||||
from third_party.pvcnn.functional.loss import kl_loss, huber_loss
|
||||
from third_party.pvcnn.functional.sampling import gather, furthest_point_sample, logits_mask
|
||||
from third_party.pvcnn.functional.voxelization import avg_voxelize
|
29
third_party/pvcnn/functional/backend.py
vendored
Normal file
29
third_party/pvcnn/functional/backend.py
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
import os
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if not os.path.exists(os.path.join(_src_path, 'build')):
|
||||
os.makedirs(os.path.join(_src_path, 'build'))
|
||||
_backend = load(name='_pvcnn_backend',
|
||||
extra_cflags=['-O3', '-std=c++17'],
|
||||
verbose=True,
|
||||
sources=[
|
||||
os.path.join(_src_path, 'src', f) for f in [
|
||||
'ball_query/ball_query.cpp',
|
||||
'ball_query/ball_query.cu',
|
||||
'grouping/grouping.cpp',
|
||||
'grouping/grouping.cu',
|
||||
'interpolate/neighbor_interpolate.cpp',
|
||||
'interpolate/neighbor_interpolate.cu',
|
||||
'interpolate/trilinear_devox.cpp',
|
||||
'interpolate/trilinear_devox.cu',
|
||||
'sampling/sampling.cpp',
|
||||
'sampling/sampling.cu',
|
||||
'voxelization/vox.cpp',
|
||||
'voxelization/vox.cu',
|
||||
'bindings.cpp',
|
||||
]
|
||||
])
|
||||
|
||||
__all__ = ['_backend']
|
20
third_party/pvcnn/functional/ball_query.py
vendored
Normal file
20
third_party/pvcnn/functional/ball_query.py
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
from torch.autograd import Function
|
||||
|
||||
from third_party.pvcnn.functional.backend import _backend
|
||||
|
||||
__all__ = ['ball_query']
|
||||
|
||||
|
||||
def ball_query(centers_coords, points_coords, radius, num_neighbors):
|
||||
"""
|
||||
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
||||
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
||||
:param radius: float, radius of ball query
|
||||
:param num_neighbors: int, maximum number of neighbors
|
||||
:return:
|
||||
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
|
||||
"""
|
||||
centers_coords = centers_coords[:,:3].contiguous()
|
||||
points_coords = points_coords[:,:3].contiguous()
|
||||
return _backend.ball_query(centers_coords, points_coords, radius,
|
||||
num_neighbors)
|
45
third_party/pvcnn/functional/devoxelization.py
vendored
Normal file
45
third_party/pvcnn/functional/devoxelization.py
vendored
Normal file
|
@ -0,0 +1,45 @@
|
|||
from torch.autograd import Function
|
||||
|
||||
from third_party.pvcnn.functional.backend import _backend
|
||||
|
||||
__all__ = ['trilinear_devoxelize']
|
||||
|
||||
|
||||
class TrilinearDevoxelization(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, features, coords, resolution, is_training=True):
|
||||
"""
|
||||
:param ctx:
|
||||
:param coords: the coordinates of points, FloatTensor[B, 3, N]
|
||||
:param features: FloatTensor[B, C, R, R, R]
|
||||
:param resolution: int, the voxel resolution
|
||||
:param is_training: bool, training mode
|
||||
:return:
|
||||
FloatTensor[B, C, N]
|
||||
"""
|
||||
B, C = features.shape[:2]
|
||||
features = features.contiguous().view(B, C, -1)
|
||||
coords = coords[:,:3].contiguous()
|
||||
outs, inds, wgts = _backend.trilinear_devoxelize_forward(
|
||||
resolution, is_training, coords, features)
|
||||
if is_training:
|
||||
ctx.save_for_backward(inds, wgts)
|
||||
ctx.r = resolution
|
||||
return outs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
:param ctx:
|
||||
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
|
||||
:return:
|
||||
gradient of inputs, FloatTensor[B, C, R, R, R]
|
||||
"""
|
||||
inds, wgts = ctx.saved_tensors
|
||||
grad_inputs = _backend.trilinear_devoxelize_backward(
|
||||
grad_output.contiguous(), inds, wgts, ctx.r)
|
||||
return grad_inputs.view(grad_output.size(0), grad_output.size(1),
|
||||
ctx.r, ctx.r, ctx.r), None, None, None
|
||||
|
||||
|
||||
trilinear_devoxelize = TrilinearDevoxelization.apply
|
33
third_party/pvcnn/functional/grouping.py
vendored
Normal file
33
third_party/pvcnn/functional/grouping.py
vendored
Normal file
|
@ -0,0 +1,33 @@
|
|||
from torch.autograd import Function
|
||||
|
||||
# from modules.functional.backend import _backend
|
||||
from third_party.pvcnn.functional.backend import _backend
|
||||
|
||||
__all__ = ['grouping']
|
||||
|
||||
|
||||
class Grouping(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, features, indices):
|
||||
"""
|
||||
:param ctx:
|
||||
:param features: features of points, FloatTensor[B, C, N]
|
||||
:param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors
|
||||
:return:
|
||||
grouped_features: grouped features, FloatTensor[B, C, M, U]
|
||||
"""
|
||||
features = features.contiguous()
|
||||
indices = indices.contiguous()
|
||||
ctx.save_for_backward(indices)
|
||||
ctx.num_points = features.size(-1)
|
||||
return _backend.grouping_forward(features, indices)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
grad_features = _backend.grouping_backward(grad_output.contiguous(),
|
||||
indices, ctx.num_points)
|
||||
return grad_features, None
|
||||
|
||||
|
||||
grouping = Grouping.apply
|
54
third_party/pvcnn/functional/interpolatation.py
vendored
Normal file
54
third_party/pvcnn/functional/interpolatation.py
vendored
Normal file
|
@ -0,0 +1,54 @@
|
|||
from torch.autograd import Function
|
||||
|
||||
# from modules.functional.backend import _backend
|
||||
from third_party.pvcnn.functional.backend import _backend
|
||||
import torch
|
||||
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||||
|
||||
__all__ = ['nearest_neighbor_interpolate']
|
||||
|
||||
|
||||
class NeighborInterpolation(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, points_coords, centers_coords, centers_features):
|
||||
"""
|
||||
:param ctx:
|
||||
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
|
||||
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
|
||||
:param centers_features: features of centers, FloatTensor[B, C, M]
|
||||
:return:
|
||||
points_features: features of points, FloatTensor[B, C, N]
|
||||
"""
|
||||
centers_coords = centers_coords[:,:3].contiguous()
|
||||
points_coords = points_coords[:,:3].contiguous()
|
||||
centers_features = centers_features.contiguous()
|
||||
points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
|
||||
points_coords, centers_coords, centers_features)
|
||||
ctx.save_for_backward(indices, weights)
|
||||
ctx.num_centers = centers_coords.size(-1)
|
||||
return points_features
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
indices, weights = ctx.saved_tensors
|
||||
grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
|
||||
grad_output.contiguous(), indices, weights, ctx.num_centers)
|
||||
return None, None, grad_centers_features
|
||||
|
||||
|
||||
nearest_neighbor_interpolate = NeighborInterpolation.apply
|
||||
|
||||
#def nearest_neighbor_interpolate(points_coords, centers_coords, centers_features):
|
||||
# # points_coords: (B,6, 64)
|
||||
# # centers_coords: (B,6, 16)
|
||||
# # centers_features: (B,128,16)
|
||||
# # interpolated_features: (B,128,64)
|
||||
# B = points_coords.shape[0]
|
||||
# D = centers_features.shape[1]
|
||||
# N = points_coords.shape[2]
|
||||
# output = torch.zeros(B,D,N).to(points_coords.shape)
|
||||
# for b in range(B):
|
||||
# for n in range(N):
|
||||
# points_coords_cur = points_coords
|
18
third_party/pvcnn/functional/loss.py
vendored
Normal file
18
third_party/pvcnn/functional/loss.py
vendored
Normal file
|
@ -0,0 +1,18 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['kl_loss', 'huber_loss']
|
||||
|
||||
|
||||
def kl_loss(x, y):
|
||||
x = F.softmax(x.detach(), dim=1)
|
||||
y = F.log_softmax(y, dim=1)
|
||||
return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1))
|
||||
|
||||
|
||||
def huber_loss(error, delta):
|
||||
abs_error = torch.abs(error)
|
||||
quadratic = torch.min(abs_error,
|
||||
torch.full_like(abs_error, fill_value=delta))
|
||||
losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
|
||||
return torch.mean(losses)
|
100
third_party/pvcnn/functional/sampling.py
vendored
Normal file
100
third_party/pvcnn/functional/sampling.py
vendored
Normal file
|
@ -0,0 +1,100 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
# from modules.functional.backend import _backend
|
||||
from third_party.pvcnn.functional.backend import _backend
|
||||
|
||||
__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
|
||||
|
||||
|
||||
class Gather(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, features, indices):
|
||||
"""
|
||||
Gather
|
||||
:param ctx:
|
||||
:param features: features of points, FloatTensor[B, C, N]
|
||||
:param indices: centers' indices in points, IntTensor[b, m]
|
||||
:return:
|
||||
centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
|
||||
"""
|
||||
features = features.contiguous()
|
||||
indices = indices.int().contiguous()
|
||||
ctx.save_for_backward(indices)
|
||||
ctx.num_points = features.size(-1)
|
||||
return _backend.gather_features_forward(features, indices)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
grad_features = _backend.gather_features_backward(
|
||||
grad_output.contiguous(), indices, ctx.num_points)
|
||||
return grad_features, None
|
||||
|
||||
|
||||
gather = Gather.apply
|
||||
|
||||
|
||||
def furthest_point_sample(coords, num_samples, normals=None):
|
||||
"""
|
||||
Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
||||
minimum distance to the sampled point set
|
||||
:param coords: coordinates of points, FloatTensor[B, 3, N]
|
||||
:param num_samples: int, M
|
||||
:return:
|
||||
center_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
|
||||
"""
|
||||
assert(len(coords.shape) == 3 and coords.shape[1] == 3), f'expect input as B,3,N; get: {coords.shape}'
|
||||
coords = coords.contiguous()
|
||||
indices = _backend.furthest_point_sampling(coords, num_samples)
|
||||
centers_coords = gather(coords, indices)
|
||||
if normals is not None:
|
||||
center_normals = gather(normals, indices)
|
||||
return centers_coords if normals is None else (centers_coords, center_normals)
|
||||
|
||||
|
||||
def logits_mask(coords, logits, num_points_per_object):
|
||||
"""
|
||||
Use logits to sample points
|
||||
:param coords: coords of points, FloatTensor[B, 3, N]
|
||||
:param logits: binary classification logits, FloatTensor[B, 2, N]
|
||||
:param num_points_per_object: M, #points per object after masking, int
|
||||
:return:
|
||||
selected_coords: FloatTensor[B, 3, M]
|
||||
masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
|
||||
mask: mask to select points, BoolTensor[B, N]
|
||||
"""
|
||||
batch_size, _, num_points = coords.shape
|
||||
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
|
||||
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
|
||||
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
|
||||
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(
|
||||
num_candidates, torch.ones_like(num_candidates)).float() # [B, C]
|
||||
selected_indices = torch.zeros((batch_size, num_points_per_object),
|
||||
device=coords.device,
|
||||
dtype=torch.int32)
|
||||
for i in range(batch_size):
|
||||
current_mask = mask[i] # [N]
|
||||
current_candidates = current_mask.nonzero().view(-1)
|
||||
current_num_candidates = current_candidates.numel()
|
||||
if current_num_candidates >= num_points_per_object:
|
||||
choices = np.random.choice(current_num_candidates,
|
||||
num_points_per_object,
|
||||
replace=False)
|
||||
selected_indices[i] = current_candidates[choices]
|
||||
elif current_num_candidates > 0:
|
||||
choices = np.concatenate([
|
||||
np.arange(current_num_candidates).repeat(
|
||||
num_points_per_object // current_num_candidates),
|
||||
np.random.choice(current_num_candidates,
|
||||
num_points_per_object %
|
||||
current_num_candidates,
|
||||
replace=False)
|
||||
])
|
||||
np.random.shuffle(choices)
|
||||
selected_indices[i] = current_candidates[choices]
|
||||
selected_coords = gather(
|
||||
masked_coords - masked_coords_mean.view(batch_size, -1, 1),
|
||||
selected_indices)
|
||||
return selected_coords, masked_coords_mean, mask
|
30
third_party/pvcnn/functional/src/ball_query/ball_query.cpp
vendored
Normal file
30
third_party/pvcnn/functional/src/ball_query/ball_query.cpp
vendored
Normal file
|
@ -0,0 +1,30 @@
|
|||
#include "ball_query.hpp"
|
||||
#include "ball_query.cuh"
|
||||
|
||||
#include "../utils.hpp"
|
||||
|
||||
at::Tensor ball_query_forward(at::Tensor centers_coords,
|
||||
at::Tensor points_coords, const float radius,
|
||||
const int num_neighbors) {
|
||||
CHECK_CUDA(centers_coords);
|
||||
CHECK_CUDA(points_coords);
|
||||
CHECK_CONTIGUOUS(centers_coords);
|
||||
CHECK_CONTIGUOUS(points_coords);
|
||||
CHECK_IS_FLOAT(centers_coords);
|
||||
CHECK_IS_FLOAT(points_coords);
|
||||
|
||||
int b = centers_coords.size(0);
|
||||
int m = centers_coords.size(2);
|
||||
int n = points_coords.size(2);
|
||||
|
||||
at::Tensor neighbors_indices = torch::zeros(
|
||||
{b, m, num_neighbors},
|
||||
at::device(centers_coords.device()).dtype(at::ScalarType::Int));
|
||||
|
||||
ball_query(b, n, m, radius * radius, num_neighbors,
|
||||
centers_coords.data_ptr<float>(),
|
||||
points_coords.data_ptr<float>(),
|
||||
neighbors_indices.data_ptr<int>());
|
||||
|
||||
return neighbors_indices;
|
||||
}
|
59
third_party/pvcnn/functional/src/ball_query/ball_query.cu
vendored
Normal file
59
third_party/pvcnn/functional/src/ball_query/ball_query.cu
vendored
Normal file
|
@ -0,0 +1,59 @@
|
|||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "../cuda_utils.cuh"
|
||||
|
||||
/*
|
||||
Function: ball query
|
||||
Args:
|
||||
b : batch size
|
||||
n : number of points in point clouds
|
||||
m : number of query centers
|
||||
r2 : ball query radius ** 2
|
||||
u : maximum number of neighbors
|
||||
centers_coords: coordinates of centers, FloatTensor[b, 3, m]
|
||||
points_coords : coordinates of points, FloatTensor[b, 3, n]
|
||||
neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
|
||||
*/
|
||||
__global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
|
||||
const float *__restrict__ centers_coords,
|
||||
const float *__restrict__ points_coords,
|
||||
int *__restrict__ neighbors_indices) {
|
||||
int batch_index = blockIdx.x;
|
||||
int index = threadIdx.x;
|
||||
int stride = blockDim.x;
|
||||
points_coords += batch_index * n * 3;
|
||||
centers_coords += batch_index * m * 3;
|
||||
neighbors_indices += batch_index * m * u;
|
||||
|
||||
for (int j = index; j < m; j += stride) {
|
||||
float center_x = centers_coords[j];
|
||||
float center_y = centers_coords[j + m];
|
||||
float center_z = centers_coords[j + m + m];
|
||||
for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
|
||||
float dx = center_x - points_coords[k];
|
||||
float dy = center_y - points_coords[k + n];
|
||||
float dz = center_z - points_coords[k + n + n];
|
||||
float d2 = dx * dx + dy * dy + dz * dz;
|
||||
if (d2 < r2) {
|
||||
if (cnt == 0) {
|
||||
for (int v = 0; v < u; ++v) {
|
||||
neighbors_indices[j * u + v] = k;
|
||||
}
|
||||
}
|
||||
neighbors_indices[j * u + cnt] = k;
|
||||
++cnt;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ball_query(int b, int n, int m, float r2, int u,
|
||||
const float *centers_coords, const float *points_coords,
|
||||
int *neighbors_indices) {
|
||||
ball_query_kernel<<<b, optimal_num_threads(m), 0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
8
third_party/pvcnn/functional/src/ball_query/ball_query.cuh
vendored
Normal file
8
third_party/pvcnn/functional/src/ball_query/ball_query.cuh
vendored
Normal file
|
@ -0,0 +1,8 @@
|
|||
#ifndef _BALL_QUERY_CUH
|
||||
#define _BALL_QUERY_CUH
|
||||
|
||||
void ball_query(int b, int n, int m, float r2, int u,
|
||||
const float *centers_coords, const float *points_coords,
|
||||
int *neighbors_indices);
|
||||
|
||||
#endif
|
10
third_party/pvcnn/functional/src/ball_query/ball_query.hpp
vendored
Normal file
10
third_party/pvcnn/functional/src/ball_query/ball_query.hpp
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
#ifndef _BALL_QUERY_HPP
|
||||
#define _BALL_QUERY_HPP
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
at::Tensor ball_query_forward(at::Tensor centers_coords,
|
||||
at::Tensor points_coords, const float radius,
|
||||
const int num_neighbors);
|
||||
|
||||
#endif
|
37
third_party/pvcnn/functional/src/bindings.cpp
vendored
Normal file
37
third_party/pvcnn/functional/src/bindings.cpp
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "ball_query/ball_query.hpp"
|
||||
#include "grouping/grouping.hpp"
|
||||
#include "interpolate/neighbor_interpolate.hpp"
|
||||
#include "interpolate/trilinear_devox.hpp"
|
||||
#include "sampling/sampling.hpp"
|
||||
#include "voxelization/vox.hpp"
|
||||
|
||||
PYBIND11_MODULE(_pvcnn_backend, m) {
|
||||
m.def("gather_features_forward", &gather_features_forward,
|
||||
"Gather Centers' Features forward (CUDA)");
|
||||
m.def("gather_features_backward", &gather_features_backward,
|
||||
"Gather Centers' Features backward (CUDA)");
|
||||
m.def("furthest_point_sampling", &furthest_point_sampling_forward,
|
||||
"Furthest Point Sampling (CUDA)");
|
||||
m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
|
||||
m.def("grouping_forward", &grouping_forward,
|
||||
"Grouping Features forward (CUDA)");
|
||||
m.def("grouping_backward", &grouping_backward,
|
||||
"Grouping Features backward (CUDA)");
|
||||
m.def("three_nearest_neighbors_interpolate_forward",
|
||||
&three_nearest_neighbors_interpolate_forward,
|
||||
"3 Nearest Neighbors Interpolate forward (CUDA)");
|
||||
m.def("three_nearest_neighbors_interpolate_backward",
|
||||
&three_nearest_neighbors_interpolate_backward,
|
||||
"3 Nearest Neighbors Interpolate backward (CUDA)");
|
||||
|
||||
m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
|
||||
"Trilinear Devoxelization forward (CUDA)");
|
||||
m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
|
||||
"Trilinear Devoxelization backward (CUDA)");
|
||||
m.def("avg_voxelize_forward", &avg_voxelize_forward,
|
||||
"Voxelization forward with average pooling (CUDA)");
|
||||
m.def("avg_voxelize_backward", &avg_voxelize_backward,
|
||||
"Voxelization backward (CUDA)");
|
||||
}
|
39
third_party/pvcnn/functional/src/cuda_utils.cuh
vendored
Normal file
39
third_party/pvcnn/functional/src/cuda_utils.cuh
vendored
Normal file
|
@ -0,0 +1,39 @@
|
|||
#ifndef _CUDA_UTILS_H
|
||||
#define _CUDA_UTILS_H
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cmath>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#define MAXIMUM_THREADS 512
|
||||
|
||||
inline int optimal_num_threads(int work_size) {
|
||||
const int pow_2 = std::log2(static_cast<double>(work_size));
|
||||
return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
|
||||
}
|
||||
|
||||
inline dim3 optimal_block_config(int x, int y) {
|
||||
const int x_threads = optimal_num_threads(x);
|
||||
const int y_threads =
|
||||
max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
|
||||
dim3 block_config(x_threads, y_threads, 1);
|
||||
return block_config;
|
||||
}
|
||||
|
||||
#define CUDA_CHECK_ERRORS() \
|
||||
{ \
|
||||
cudaError_t err = cudaGetLastError(); \
|
||||
if (cudaSuccess != err) { \
|
||||
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
|
||||
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
|
||||
__FILE__); \
|
||||
exit(-1); \
|
||||
} \
|
||||
}
|
||||
|
||||
#endif
|
44
third_party/pvcnn/functional/src/grouping/grouping.cpp
vendored
Normal file
44
third_party/pvcnn/functional/src/grouping/grouping.cpp
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
#include "grouping.hpp"
|
||||
#include "grouping.cuh"
|
||||
|
||||
#include "../utils.hpp"
|
||||
|
||||
at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
|
||||
CHECK_CUDA(features);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CONTIGUOUS(features);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_IS_FLOAT(features);
|
||||
CHECK_IS_INT(indices);
|
||||
|
||||
int b = features.size(0);
|
||||
int c = features.size(1);
|
||||
int n = features.size(2);
|
||||
int m = indices.size(1);
|
||||
int u = indices.size(2);
|
||||
at::Tensor output = torch::zeros(
|
||||
{b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
|
||||
grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
|
||||
output.data_ptr<float>());
|
||||
return output;
|
||||
}
|
||||
|
||||
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
|
||||
const int n) {
|
||||
CHECK_CUDA(grad_y);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CONTIGUOUS(grad_y);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_IS_FLOAT(grad_y);
|
||||
CHECK_IS_INT(indices);
|
||||
|
||||
int b = grad_y.size(0);
|
||||
int c = grad_y.size(1);
|
||||
int m = indices.size(1);
|
||||
int u = indices.size(2);
|
||||
at::Tensor grad_x = torch::zeros(
|
||||
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
|
||||
grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
|
||||
indices.data_ptr<int>(), grad_x.data_ptr<float>());
|
||||
return grad_x;
|
||||
}
|
85
third_party/pvcnn/functional/src/grouping/grouping.cu
vendored
Normal file
85
third_party/pvcnn/functional/src/grouping/grouping.cu
vendored
Normal file
|
@ -0,0 +1,85 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "../cuda_utils.cuh"
|
||||
|
||||
/*
|
||||
Function: grouping features of neighbors (forward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channles of features
|
||||
n : number of points in point clouds
|
||||
m : number of query centers
|
||||
u : maximum number of neighbors
|
||||
features: points' features, FloatTensor[b, c, n]
|
||||
indices : neighbor indices in points, IntTensor[b, m, u]
|
||||
out : gathered features, FloatTensor[b, c, m, u]
|
||||
*/
|
||||
__global__ void grouping_kernel(int b, int c, int n, int m, int u,
|
||||
const float *__restrict__ features,
|
||||
const int *__restrict__ indices,
|
||||
float *__restrict__ out) {
|
||||
int batch_index = blockIdx.x;
|
||||
features += batch_index * n * c;
|
||||
indices += batch_index * m * u;
|
||||
out += batch_index * m * u * c;
|
||||
|
||||
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
const int stride = blockDim.y * blockDim.x;
|
||||
for (int i = index; i < c * m; i += stride) {
|
||||
const int l = i / m;
|
||||
const int j = i % m;
|
||||
for (int k = 0; k < u; ++k) {
|
||||
out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void grouping(int b, int c, int n, int m, int u, const float *features,
|
||||
const int *indices, float *out) {
|
||||
grouping_kernel<<<b, optimal_block_config(m, c), 0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(b, c, n, m, u, features,
|
||||
indices, out);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
||||
|
||||
/*
|
||||
Function: grouping features of neighbors (backward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channles of features
|
||||
n : number of points in point clouds
|
||||
m : number of query centers
|
||||
u : maximum number of neighbors
|
||||
grad_y : grad of gathered features, FloatTensor[b, c, m, u]
|
||||
indices : neighbor indices in points, IntTensor[b, m, u]
|
||||
grad_x: grad of points' features, FloatTensor[b, c, n]
|
||||
*/
|
||||
__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
|
||||
const float *__restrict__ grad_y,
|
||||
const int *__restrict__ indices,
|
||||
float *__restrict__ grad_x) {
|
||||
int batch_index = blockIdx.x;
|
||||
grad_y += batch_index * m * u * c;
|
||||
indices += batch_index * m * u;
|
||||
grad_x += batch_index * n * c;
|
||||
|
||||
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
const int stride = blockDim.y * blockDim.x;
|
||||
for (int i = index; i < c * m; i += stride) {
|
||||
const int l = i / m;
|
||||
const int j = i % m;
|
||||
for (int k = 0; k < u; ++k) {
|
||||
atomicAdd(grad_x + l * n + indices[j * u + k],
|
||||
grad_y[(l * m + j) * u + k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
|
||||
const int *indices, float *grad_x) {
|
||||
grouping_grad_kernel<<<b, optimal_block_config(m, c), 0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, c, n, m, u, grad_y, indices, grad_x);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
9
third_party/pvcnn/functional/src/grouping/grouping.cuh
vendored
Normal file
9
third_party/pvcnn/functional/src/grouping/grouping.cuh
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
#ifndef _GROUPING_CUH
|
||||
#define _GROUPING_CUH
|
||||
|
||||
void grouping(int b, int c, int n, int m, int u, const float *features,
|
||||
const int *indices, float *out);
|
||||
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
|
||||
const int *indices, float *grad_x);
|
||||
|
||||
#endif
|
10
third_party/pvcnn/functional/src/grouping/grouping.hpp
vendored
Normal file
10
third_party/pvcnn/functional/src/grouping/grouping.hpp
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
#ifndef _GROUPING_HPP
|
||||
#define _GROUPING_HPP
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
|
||||
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
|
||||
const int n);
|
||||
|
||||
#endif
|
65
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cpp
vendored
Normal file
65
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cpp
vendored
Normal file
|
@ -0,0 +1,65 @@
|
|||
#include "neighbor_interpolate.hpp"
|
||||
#include "neighbor_interpolate.cuh"
|
||||
|
||||
#include "../utils.hpp"
|
||||
|
||||
std::vector<at::Tensor>
|
||||
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
|
||||
at::Tensor centers_coords,
|
||||
at::Tensor centers_features) {
|
||||
CHECK_CUDA(points_coords);
|
||||
CHECK_CUDA(centers_coords);
|
||||
CHECK_CUDA(centers_features);
|
||||
CHECK_CONTIGUOUS(points_coords);
|
||||
CHECK_CONTIGUOUS(centers_coords);
|
||||
CHECK_CONTIGUOUS(centers_features);
|
||||
CHECK_IS_FLOAT(points_coords);
|
||||
CHECK_IS_FLOAT(centers_coords);
|
||||
CHECK_IS_FLOAT(centers_features);
|
||||
|
||||
int b = centers_features.size(0);
|
||||
int c = centers_features.size(1);
|
||||
int m = centers_features.size(2);
|
||||
int n = points_coords.size(2);
|
||||
|
||||
at::Tensor indices = torch::zeros(
|
||||
{b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
|
||||
at::Tensor weights = torch::zeros(
|
||||
{b, 3, n},
|
||||
at::device(points_coords.device()).dtype(at::ScalarType::Float));
|
||||
at::Tensor output = torch::zeros(
|
||||
{b, c, n},
|
||||
at::device(centers_features.device()).dtype(at::ScalarType::Float));
|
||||
|
||||
three_nearest_neighbors_interpolate(
|
||||
b, c, m, n, points_coords.data_ptr<float>(),
|
||||
centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
|
||||
indices.data_ptr<int>(), weights.data_ptr<float>(),
|
||||
output.data_ptr<float>());
|
||||
return {output, indices, weights};
|
||||
}
|
||||
|
||||
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
|
||||
at::Tensor indices,
|
||||
at::Tensor weights,
|
||||
const int m) {
|
||||
CHECK_CUDA(grad_y);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CUDA(weights);
|
||||
CHECK_CONTIGUOUS(grad_y);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_CONTIGUOUS(weights);
|
||||
CHECK_IS_FLOAT(grad_y);
|
||||
CHECK_IS_INT(indices);
|
||||
CHECK_IS_FLOAT(weights);
|
||||
|
||||
int b = grad_y.size(0);
|
||||
int c = grad_y.size(1);
|
||||
int n = grad_y.size(2);
|
||||
at::Tensor grad_x = torch::zeros(
|
||||
{b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
|
||||
three_nearest_neighbors_interpolate_grad(
|
||||
b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
|
||||
weights.data_ptr<float>(), grad_x.data_ptr<float>());
|
||||
return grad_x;
|
||||
}
|
181
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cu
vendored
Normal file
181
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cu
vendored
Normal file
|
@ -0,0 +1,181 @@
|
|||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "../cuda_utils.cuh"
|
||||
|
||||
/*
|
||||
Function: three nearest neighbors
|
||||
Args:
|
||||
b : batch size
|
||||
n : number of points in point clouds
|
||||
m : number of query centers
|
||||
points_coords : coordinates of points, FloatTensor[b, 3, n]
|
||||
centers_coords: coordinates of centers, FloatTensor[b, 3, m]
|
||||
weights : weights of nearest 3 centers to the point,
|
||||
FloatTensor[b, 3, n]
|
||||
indices : indices of nearest 3 centers to the point,
|
||||
IntTensor[b, 3, n]
|
||||
*/
|
||||
__global__ void three_nearest_neighbors_kernel(
|
||||
int b, int n, int m, const float *__restrict__ points_coords,
|
||||
const float *__restrict__ centers_coords, float *__restrict__ weights,
|
||||
int *__restrict__ indices) {
|
||||
int batch_index = blockIdx.x;
|
||||
int index = threadIdx.x;
|
||||
int stride = blockDim.x;
|
||||
points_coords += batch_index * 3 * n;
|
||||
weights += batch_index * 3 * n;
|
||||
indices += batch_index * 3 * n;
|
||||
centers_coords += batch_index * 3 * m;
|
||||
|
||||
for (int j = index; j < n; j += stride) {
|
||||
float ux = points_coords[j];
|
||||
float uy = points_coords[j + n];
|
||||
float uz = points_coords[j + n + n];
|
||||
|
||||
double best0 = 1e40, best1 = 1e40, best2 = 1e40;
|
||||
int besti0 = 0, besti1 = 0, besti2 = 0;
|
||||
for (int k = 0; k < m; ++k) {
|
||||
float x = centers_coords[k];
|
||||
float y = centers_coords[k + m];
|
||||
float z = centers_coords[k + m + m];
|
||||
float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
|
||||
if (d < best2) {
|
||||
best2 = d;
|
||||
besti2 = k;
|
||||
if (d < best1) {
|
||||
best2 = best1;
|
||||
besti2 = besti1;
|
||||
best1 = d;
|
||||
besti1 = k;
|
||||
if (d < best0) {
|
||||
best1 = best0;
|
||||
besti1 = besti0;
|
||||
best0 = d;
|
||||
besti0 = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
best0 = max(min(1e10f, best0), 1e-10f);
|
||||
best1 = max(min(1e10f, best1), 1e-10f);
|
||||
best2 = max(min(1e10f, best2), 1e-10f);
|
||||
float d0d1 = best0 * best1;
|
||||
float d0d2 = best0 * best2;
|
||||
float d1d2 = best1 * best2;
|
||||
float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
|
||||
weights[j] = d1d2 * d0d1d2;
|
||||
indices[j] = besti0;
|
||||
weights[j + n] = d0d2 * d0d1d2;
|
||||
indices[j + n] = besti1;
|
||||
weights[j + n + n] = d0d1 * d0d1d2;
|
||||
indices[j + n + n] = besti2;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Function: interpolate three nearest neighbors (forward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channels of features
|
||||
m : number of query centers
|
||||
n : number of points in point clouds
|
||||
centers_features: features of centers, FloatTensor[b, c, m]
|
||||
indices : indices of nearest 3 centers to the point,
|
||||
IntTensor[b, 3, n]
|
||||
weights : weights for interpolation, FloatTensor[b, 3, n]
|
||||
out : features of points, FloatTensor[b, c, n]
|
||||
*/
|
||||
__global__ void three_nearest_neighbors_interpolate_kernel(
|
||||
int b, int c, int m, int n, const float *__restrict__ centers_features,
|
||||
const int *__restrict__ indices, const float *__restrict__ weights,
|
||||
float *__restrict__ out) {
|
||||
int batch_index = blockIdx.x;
|
||||
centers_features += batch_index * m * c;
|
||||
indices += batch_index * n * 3;
|
||||
weights += batch_index * n * 3;
|
||||
out += batch_index * n * c;
|
||||
|
||||
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
const int stride = blockDim.y * blockDim.x;
|
||||
for (int i = index; i < c * n; i += stride) {
|
||||
const int l = i / n;
|
||||
const int j = i % n;
|
||||
float w1 = weights[j];
|
||||
float w2 = weights[j + n];
|
||||
float w3 = weights[j + n + n];
|
||||
int i1 = indices[j];
|
||||
int i2 = indices[j + n];
|
||||
int i3 = indices[j + n + n];
|
||||
|
||||
out[i] = centers_features[l * m + i1] * w1 +
|
||||
centers_features[l * m + i2] * w2 +
|
||||
centers_features[l * m + i3] * w3;
|
||||
}
|
||||
}
|
||||
|
||||
void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
|
||||
const float *points_coords,
|
||||
const float *centers_coords,
|
||||
const float *centers_features,
|
||||
int *indices, float *weights,
|
||||
float *out) {
|
||||
three_nearest_neighbors_kernel<<<b, optimal_num_threads(n), 0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, n, m, points_coords, centers_coords, weights, indices);
|
||||
three_nearest_neighbors_interpolate_kernel<<<
|
||||
b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, c, m, n, centers_features, indices, weights, out);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
||||
|
||||
/*
|
||||
Function: interpolate three nearest neighbors (backward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channels of features
|
||||
m : number of query centers
|
||||
n : number of points in point clouds
|
||||
grad_y : grad of features of points, FloatTensor[b, c, n]
|
||||
indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
|
||||
weights : weights for interpolation, FloatTensor[b, 3, n]
|
||||
grad_x : grad of features of centers, FloatTensor[b, c, m]
|
||||
*/
|
||||
__global__ void three_nearest_neighbors_interpolate_grad_kernel(
|
||||
int b, int c, int n, int m, const float *__restrict__ grad_y,
|
||||
const int *__restrict__ indices, const float *__restrict__ weights,
|
||||
float *__restrict__ grad_x) {
|
||||
int batch_index = blockIdx.x;
|
||||
grad_y += batch_index * n * c;
|
||||
indices += batch_index * n * 3;
|
||||
weights += batch_index * n * 3;
|
||||
grad_x += batch_index * m * c;
|
||||
|
||||
const int index = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
const int stride = blockDim.y * blockDim.x;
|
||||
for (int i = index; i < c * n; i += stride) {
|
||||
const int l = i / n;
|
||||
const int j = i % n;
|
||||
float w1 = weights[j];
|
||||
float w2 = weights[j + n];
|
||||
float w3 = weights[j + n + n];
|
||||
int i1 = indices[j];
|
||||
int i2 = indices[j + n];
|
||||
int i3 = indices[j + n + n];
|
||||
atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
|
||||
atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
|
||||
atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
|
||||
}
|
||||
}
|
||||
|
||||
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
|
||||
const float *grad_y,
|
||||
const int *indices,
|
||||
const float *weights,
|
||||
float *grad_x) {
|
||||
three_nearest_neighbors_interpolate_grad_kernel<<<
|
||||
b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, c, n, m, grad_y, indices, weights, grad_x);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
16
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cuh
vendored
Normal file
16
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.cuh
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
#ifndef _NEIGHBOR_INTERPOLATE_CUH
|
||||
#define _NEIGHBOR_INTERPOLATE_CUH
|
||||
|
||||
void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
|
||||
const float *points_coords,
|
||||
const float *centers_coords,
|
||||
const float *centers_features,
|
||||
int *indices, float *weights,
|
||||
float *out);
|
||||
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
|
||||
const float *grad_y,
|
||||
const int *indices,
|
||||
const float *weights,
|
||||
float *grad_x);
|
||||
|
||||
#endif
|
16
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.hpp
vendored
Normal file
16
third_party/pvcnn/functional/src/interpolate/neighbor_interpolate.hpp
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
#ifndef _NEIGHBOR_INTERPOLATE_HPP
|
||||
#define _NEIGHBOR_INTERPOLATE_HPP
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
std::vector<at::Tensor>
|
||||
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
|
||||
at::Tensor centers_coords,
|
||||
at::Tensor centers_features);
|
||||
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
|
||||
at::Tensor indices,
|
||||
at::Tensor weights,
|
||||
const int m);
|
||||
|
||||
#endif
|
91
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cpp
vendored
Normal file
91
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cpp
vendored
Normal file
|
@ -0,0 +1,91 @@
|
|||
#include "trilinear_devox.hpp"
|
||||
#include "trilinear_devox.cuh"
|
||||
|
||||
#include "../utils.hpp"
|
||||
|
||||
/*
|
||||
Function: trilinear devoxelization (forward)
|
||||
Args:
|
||||
r : voxel resolution
|
||||
trainig : whether is training mode
|
||||
coords : the coordinates of points, FloatTensor[b, 3, n]
|
||||
features : features, FloatTensor[b, c, s], s = r ** 3
|
||||
Return:
|
||||
outs : outputs, FloatTensor[b, c, n]
|
||||
inds : the voxel coordinates of point cube, IntTensor[b, 8, n]
|
||||
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
|
||||
*/
|
||||
std::vector<at::Tensor>
|
||||
trilinear_devoxelize_forward(const int r, const bool is_training,
|
||||
const at::Tensor coords,
|
||||
const at::Tensor features) {
|
||||
CHECK_CUDA(features);
|
||||
CHECK_CUDA(coords);
|
||||
CHECK_CONTIGUOUS(features);
|
||||
CHECK_CONTIGUOUS(coords);
|
||||
CHECK_IS_FLOAT(features);
|
||||
CHECK_IS_FLOAT(coords);
|
||||
|
||||
int b = features.size(0);
|
||||
int c = features.size(1);
|
||||
int n = coords.size(2);
|
||||
int r2 = r * r;
|
||||
int r3 = r2 * r;
|
||||
at::Tensor outs = torch::zeros(
|
||||
{b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float));
|
||||
if (is_training) {
|
||||
at::Tensor inds = torch::zeros(
|
||||
{b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Int));
|
||||
at::Tensor wgts = torch::zeros(
|
||||
{b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float));
|
||||
trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr<float>(),
|
||||
features.data_ptr<float>(), inds.data_ptr<int>(),
|
||||
wgts.data_ptr<float>(), outs.data_ptr<float>());
|
||||
return {outs, inds, wgts};
|
||||
} else {
|
||||
at::Tensor inds = torch::zeros(
|
||||
{1}, at::device(features.device()).dtype(at::ScalarType::Int));
|
||||
at::Tensor wgts = torch::zeros(
|
||||
{1}, at::device(features.device()).dtype(at::ScalarType::Float));
|
||||
trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr<float>(),
|
||||
features.data_ptr<float>(), inds.data_ptr<int>(),
|
||||
wgts.data_ptr<float>(), outs.data_ptr<float>());
|
||||
return {outs, inds, wgts};
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Function: trilinear devoxelization (backward)
|
||||
Args:
|
||||
grad_y : grad outputs, FloatTensor[b, c, n]
|
||||
indices : the voxel coordinates of point cube, IntTensor[b, 8, n]
|
||||
weights : weight for trilinear interpolation, FloatTensor[b, 8, n]
|
||||
r : voxel resolution
|
||||
Return:
|
||||
grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3
|
||||
*/
|
||||
at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
|
||||
const at::Tensor indices,
|
||||
const at::Tensor weights,
|
||||
const int r) {
|
||||
CHECK_CUDA(grad_y);
|
||||
CHECK_CUDA(weights);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CONTIGUOUS(grad_y);
|
||||
CHECK_CONTIGUOUS(weights);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_IS_FLOAT(grad_y);
|
||||
CHECK_IS_FLOAT(weights);
|
||||
CHECK_IS_INT(indices);
|
||||
|
||||
int b = grad_y.size(0);
|
||||
int c = grad_y.size(1);
|
||||
int n = grad_y.size(2);
|
||||
int r3 = r * r * r;
|
||||
at::Tensor grad_x = torch::zeros(
|
||||
{b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
|
||||
trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr<int>(),
|
||||
weights.data_ptr<float>(), grad_y.data_ptr<float>(),
|
||||
grad_x.data_ptr<float>());
|
||||
return grad_x;
|
||||
}
|
178
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cu
vendored
Normal file
178
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cu
vendored
Normal file
|
@ -0,0 +1,178 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "../cuda_utils.cuh"
|
||||
|
||||
/*
|
||||
Function: trilinear devoxlization (forward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channels
|
||||
n : number of points
|
||||
r : voxel resolution
|
||||
r2 : r ** 2
|
||||
r3 : r ** 3
|
||||
coords : the coordinates of points, FloatTensor[b, 3, n]
|
||||
feat : features, FloatTensor[b, c, r3]
|
||||
inds : the voxel indices of point cube, IntTensor[b, 8, n]
|
||||
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
|
||||
outs : outputs, FloatTensor[b, c, n]
|
||||
*/
|
||||
__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2,
|
||||
int r3, bool is_training,
|
||||
const float *__restrict__ coords,
|
||||
const float *__restrict__ feat,
|
||||
int *__restrict__ inds,
|
||||
float *__restrict__ wgts,
|
||||
float *__restrict__ outs) {
|
||||
int batch_index = blockIdx.x;
|
||||
int stride = blockDim.x;
|
||||
int index = threadIdx.x;
|
||||
coords += batch_index * n * 3;
|
||||
inds += batch_index * n * 8;
|
||||
wgts += batch_index * n * 8;
|
||||
feat += batch_index * c * r3;
|
||||
outs += batch_index * c * n;
|
||||
|
||||
for (int i = index; i < n; i += stride) {
|
||||
float x = coords[i];
|
||||
float y = coords[i + n];
|
||||
float z = coords[i + n + n];
|
||||
float x_lo_f = floorf(x);
|
||||
float y_lo_f = floorf(y);
|
||||
float z_lo_f = floorf(z);
|
||||
|
||||
float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f)
|
||||
float y_d_1 = y - y_lo_f;
|
||||
float z_d_1 = z - z_lo_f;
|
||||
float x_d_0 = 1.0f - x_d_1;
|
||||
float y_d_0 = 1.0f - y_d_1;
|
||||
float z_d_0 = 1.0f - z_d_1;
|
||||
|
||||
float wgt000 = x_d_0 * y_d_0 * z_d_0;
|
||||
float wgt001 = x_d_0 * y_d_0 * z_d_1;
|
||||
float wgt010 = x_d_0 * y_d_1 * z_d_0;
|
||||
float wgt011 = x_d_0 * y_d_1 * z_d_1;
|
||||
float wgt100 = x_d_1 * y_d_0 * z_d_0;
|
||||
float wgt101 = x_d_1 * y_d_0 * z_d_1;
|
||||
float wgt110 = x_d_1 * y_d_1 * z_d_0;
|
||||
float wgt111 = x_d_1 * y_d_1 * z_d_1;
|
||||
|
||||
int x_lo = static_cast<int>(x_lo_f);
|
||||
int y_lo = static_cast<int>(y_lo_f);
|
||||
int z_lo = static_cast<int>(z_lo_f);
|
||||
int x_hi = (x_d_1 > 0) ? -1 : 0;
|
||||
int y_hi = (y_d_1 > 0) ? -1 : 0;
|
||||
int z_hi = (z_d_1 > 0) ? 1 : 0;
|
||||
|
||||
int idx000 = x_lo * r2 + y_lo * r + z_lo;
|
||||
int idx001 = idx000 + z_hi; // x_lo * r2 + y_lo * r + z_hi;
|
||||
int idx010 = idx000 + (y_hi & r); // x_lo * r2 + y_hi * r + z_lo;
|
||||
int idx011 = idx010 + z_hi; // x_lo * r2 + y_hi * r + z_hi;
|
||||
int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo;
|
||||
int idx101 = idx100 + z_hi; // x_hi * r2 + y_lo * r + z_hi;
|
||||
int idx110 = idx100 + (y_hi & r); // x_hi * r2 + y_hi * r + z_lo;
|
||||
int idx111 = idx110 + z_hi; // x_hi * r2 + y_hi * r + z_hi;
|
||||
|
||||
if (is_training) {
|
||||
wgts[i] = wgt000;
|
||||
wgts[i + n] = wgt001;
|
||||
wgts[i + n * 2] = wgt010;
|
||||
wgts[i + n * 3] = wgt011;
|
||||
wgts[i + n * 4] = wgt100;
|
||||
wgts[i + n * 5] = wgt101;
|
||||
wgts[i + n * 6] = wgt110;
|
||||
wgts[i + n * 7] = wgt111;
|
||||
inds[i] = idx000;
|
||||
inds[i + n] = idx001;
|
||||
inds[i + n * 2] = idx010;
|
||||
inds[i + n * 3] = idx011;
|
||||
inds[i + n * 4] = idx100;
|
||||
inds[i + n * 5] = idx101;
|
||||
inds[i + n * 6] = idx110;
|
||||
inds[i + n * 7] = idx111;
|
||||
}
|
||||
|
||||
for (int j = 0; j < c; j++) {
|
||||
int jr3 = j * r3;
|
||||
outs[j * n + i] =
|
||||
wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] +
|
||||
wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] +
|
||||
wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] +
|
||||
wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Function: trilinear devoxlization (backward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channels
|
||||
n : number of points
|
||||
r3 : voxel cube size = voxel resolution ** 3
|
||||
inds : the voxel indices of point cube, IntTensor[b, 8, n]
|
||||
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
|
||||
grad_y : grad outputs, FloatTensor[b, c, n]
|
||||
grad_x : grad inputs, FloatTensor[b, c, r3]
|
||||
*/
|
||||
__global__ void trilinear_devoxelize_grad_kernel(
|
||||
int b, int c, int n, int r3, const int *__restrict__ inds,
|
||||
const float *__restrict__ wgts, const float *__restrict__ grad_y,
|
||||
float *__restrict__ grad_x) {
|
||||
int batch_index = blockIdx.x;
|
||||
int stride = blockDim.x;
|
||||
int index = threadIdx.x;
|
||||
inds += batch_index * n * 8;
|
||||
wgts += batch_index * n * 8;
|
||||
grad_x += batch_index * c * r3;
|
||||
grad_y += batch_index * c * n;
|
||||
|
||||
for (int i = index; i < n; i += stride) {
|
||||
int idx000 = inds[i];
|
||||
int idx001 = inds[i + n];
|
||||
int idx010 = inds[i + n * 2];
|
||||
int idx011 = inds[i + n * 3];
|
||||
int idx100 = inds[i + n * 4];
|
||||
int idx101 = inds[i + n * 5];
|
||||
int idx110 = inds[i + n * 6];
|
||||
int idx111 = inds[i + n * 7];
|
||||
float wgt000 = wgts[i];
|
||||
float wgt001 = wgts[i + n];
|
||||
float wgt010 = wgts[i + n * 2];
|
||||
float wgt011 = wgts[i + n * 3];
|
||||
float wgt100 = wgts[i + n * 4];
|
||||
float wgt101 = wgts[i + n * 5];
|
||||
float wgt110 = wgts[i + n * 6];
|
||||
float wgt111 = wgts[i + n * 7];
|
||||
|
||||
for (int j = 0; j < c; j++) {
|
||||
int jr3 = j * r3;
|
||||
float g = grad_y[j * n + i];
|
||||
atomicAdd(grad_x + jr3 + idx000, wgt000 * g);
|
||||
atomicAdd(grad_x + jr3 + idx001, wgt001 * g);
|
||||
atomicAdd(grad_x + jr3 + idx010, wgt010 * g);
|
||||
atomicAdd(grad_x + jr3 + idx011, wgt011 * g);
|
||||
atomicAdd(grad_x + jr3 + idx100, wgt100 * g);
|
||||
atomicAdd(grad_x + jr3 + idx101, wgt101 * g);
|
||||
atomicAdd(grad_x + jr3 + idx110, wgt110 * g);
|
||||
atomicAdd(grad_x + jr3 + idx111, wgt111 * g);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
|
||||
bool training, const float *coords, const float *feat,
|
||||
int *inds, float *wgts, float *outs) {
|
||||
trilinear_devoxelize_kernel<<<b, optimal_num_threads(n)>>>(
|
||||
b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
||||
|
||||
void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
|
||||
const float *wgts, const float *grad_y,
|
||||
float *grad_x) {
|
||||
trilinear_devoxelize_grad_kernel<<<b, optimal_num_threads(n)>>>(
|
||||
b, c, n, r3, inds, wgts, grad_y, grad_x);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
13
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cuh
vendored
Normal file
13
third_party/pvcnn/functional/src/interpolate/trilinear_devox.cuh
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
#ifndef _TRILINEAR_DEVOX_CUH
|
||||
#define _TRILINEAR_DEVOX_CUH
|
||||
|
||||
// CUDA function declarations
|
||||
void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
|
||||
bool is_training, const float *coords,
|
||||
const float *feat, int *inds, float *wgts,
|
||||
float *outs);
|
||||
void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
|
||||
const float *wgts, const float *grad_y,
|
||||
float *grad_x);
|
||||
|
||||
#endif
|
16
third_party/pvcnn/functional/src/interpolate/trilinear_devox.hpp
vendored
Normal file
16
third_party/pvcnn/functional/src/interpolate/trilinear_devox.hpp
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
#ifndef _TRILINEAR_DEVOX_HPP
|
||||
#define _TRILINEAR_DEVOX_HPP
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
std::vector<at::Tensor> trilinear_devoxelize_forward(const int r,
|
||||
const bool is_training,
|
||||
const at::Tensor coords,
|
||||
const at::Tensor features);
|
||||
|
||||
at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
|
||||
const at::Tensor indices,
|
||||
const at::Tensor weights, const int r);
|
||||
|
||||
#endif
|
58
third_party/pvcnn/functional/src/sampling/sampling.cpp
vendored
Normal file
58
third_party/pvcnn/functional/src/sampling/sampling.cpp
vendored
Normal file
|
@ -0,0 +1,58 @@
|
|||
#include "sampling.hpp"
|
||||
#include "sampling.cuh"
|
||||
|
||||
#include "../utils.hpp"
|
||||
|
||||
at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) {
|
||||
CHECK_CUDA(features);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CONTIGUOUS(features);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_IS_FLOAT(features);
|
||||
CHECK_IS_INT(indices);
|
||||
|
||||
int b = features.size(0);
|
||||
int c = features.size(1);
|
||||
int n = features.size(2);
|
||||
int m = indices.size(1);
|
||||
at::Tensor output = torch::zeros(
|
||||
{b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float));
|
||||
gather_features(b, c, n, m, features.data_ptr<float>(),
|
||||
indices.data_ptr<int>(), output.data_ptr<float>());
|
||||
return output;
|
||||
}
|
||||
|
||||
at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
|
||||
const int n) {
|
||||
CHECK_CUDA(grad_y);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CONTIGUOUS(grad_y);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_IS_FLOAT(grad_y);
|
||||
CHECK_IS_INT(indices);
|
||||
|
||||
int b = grad_y.size(0);
|
||||
int c = grad_y.size(1);
|
||||
at::Tensor grad_x = torch::zeros(
|
||||
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
|
||||
gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr<float>(),
|
||||
indices.data_ptr<int>(), grad_x.data_ptr<float>());
|
||||
return grad_x;
|
||||
}
|
||||
|
||||
at::Tensor furthest_point_sampling_forward(at::Tensor coords,
|
||||
const int num_samples) {
|
||||
CHECK_CUDA(coords);
|
||||
CHECK_CONTIGUOUS(coords);
|
||||
CHECK_IS_FLOAT(coords);
|
||||
|
||||
int b = coords.size(0);
|
||||
int n = coords.size(2);
|
||||
at::Tensor indices = torch::zeros(
|
||||
{b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int));
|
||||
at::Tensor distances = torch::full(
|
||||
{b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float));
|
||||
furthest_point_sampling(b, n, num_samples, coords.data_ptr<float>(),
|
||||
distances.data_ptr<float>(), indices.data_ptr<int>());
|
||||
return indices;
|
||||
}
|
174
third_party/pvcnn/functional/src/sampling/sampling.cu
vendored
Normal file
174
third_party/pvcnn/functional/src/sampling/sampling.cu
vendored
Normal file
|
@ -0,0 +1,174 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "../cuda_utils.cuh"
|
||||
|
||||
/*
|
||||
Function: gather centers' features (forward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channles of features
|
||||
n : number of points in point clouds
|
||||
m : number of query/sampled centers
|
||||
features: points' features, FloatTensor[b, c, n]
|
||||
indices : centers' indices in points, IntTensor[b, m]
|
||||
out : gathered features, FloatTensor[b, c, m]
|
||||
*/
|
||||
__global__ void gather_features_kernel(int b, int c, int n, int m,
|
||||
const float *__restrict__ features,
|
||||
const int *__restrict__ indices,
|
||||
float *__restrict__ out) {
|
||||
int batch_index = blockIdx.x;
|
||||
int channel_index = blockIdx.y;
|
||||
int temp_index = batch_index * c + channel_index;
|
||||
features += temp_index * n;
|
||||
indices += batch_index * m;
|
||||
out += temp_index * m;
|
||||
|
||||
for (int j = threadIdx.x; j < m; j += blockDim.x) {
|
||||
out[j] = features[indices[j]];
|
||||
}
|
||||
}
|
||||
|
||||
void gather_features(int b, int c, int n, int m, const float *features,
|
||||
const int *indices, float *out) {
|
||||
gather_features_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, c, n, m, features, indices, out);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
||||
|
||||
/*
|
||||
Function: gather centers' features (backward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channles of features
|
||||
n : number of points in point clouds
|
||||
m : number of query/sampled centers
|
||||
grad_y : grad of gathered features, FloatTensor[b, c, m]
|
||||
indices : centers' indices in points, IntTensor[b, m]
|
||||
grad_x : grad of points' features, FloatTensor[b, c, n]
|
||||
*/
|
||||
__global__ void gather_features_grad_kernel(int b, int c, int n, int m,
|
||||
const float *__restrict__ grad_y,
|
||||
const int *__restrict__ indices,
|
||||
float *__restrict__ grad_x) {
|
||||
int batch_index = blockIdx.x;
|
||||
int channel_index = blockIdx.y;
|
||||
int temp_index = batch_index * c + channel_index;
|
||||
grad_y += temp_index * m;
|
||||
indices += batch_index * m;
|
||||
grad_x += temp_index * n;
|
||||
|
||||
for (int j = threadIdx.x; j < m; j += blockDim.x) {
|
||||
atomicAdd(grad_x + indices[j], grad_y[j]);
|
||||
}
|
||||
}
|
||||
|
||||
void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
|
||||
const int *indices, float *grad_x) {
|
||||
gather_features_grad_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0,
|
||||
at::cuda::getCurrentCUDAStream()>>>(
|
||||
b, c, n, m, grad_y, indices, grad_x);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
||||
|
||||
/*
|
||||
Function: furthest point sampling
|
||||
Args:
|
||||
b : batch size
|
||||
n : number of points in point clouds
|
||||
m : number of query/sampled centers
|
||||
coords : points' coords, FloatTensor[b, 3, n]
|
||||
distances : minimum distance of a point to the set, IntTensor[b, n]
|
||||
indices : sampled centers' indices in points, IntTensor[b, m]
|
||||
*/
|
||||
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
|
||||
const float *__restrict__ coords,
|
||||
float *__restrict__ distances,
|
||||
int *__restrict__ indices) {
|
||||
if (m <= 0)
|
||||
return;
|
||||
int batch_index = blockIdx.x;
|
||||
coords += batch_index * n * 3;
|
||||
distances += batch_index * n;
|
||||
indices += batch_index * m;
|
||||
|
||||
const int BlockSize = 512;
|
||||
__shared__ float dists[BlockSize];
|
||||
__shared__ int dists_i[BlockSize];
|
||||
const int BufferSize = 3072;
|
||||
__shared__ float buf[BufferSize * 3];
|
||||
|
||||
int old = 0;
|
||||
if (threadIdx.x == 0)
|
||||
indices[0] = old;
|
||||
|
||||
for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) {
|
||||
buf[j] = coords[j];
|
||||
buf[j + BufferSize] = coords[j + n];
|
||||
buf[j + BufferSize + BufferSize] = coords[j + n + n];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int j = 1; j < m; j++) {
|
||||
int besti = 0; // best index
|
||||
float best = -1; // farthest distance
|
||||
// calculating the distance with the latest sampled point
|
||||
float x1 = coords[old];
|
||||
float y1 = coords[old + n];
|
||||
float z1 = coords[old + n + n];
|
||||
for (int k = threadIdx.x; k < n; k += blockDim.x) {
|
||||
// fetch distance at block n, thread k
|
||||
float td = distances[k];
|
||||
float x2, y2, z2;
|
||||
if (k < BufferSize) {
|
||||
x2 = buf[k];
|
||||
y2 = buf[k + BufferSize];
|
||||
z2 = buf[k + BufferSize + BufferSize];
|
||||
} else {
|
||||
x2 = coords[k];
|
||||
y2 = coords[k + n];
|
||||
z2 = coords[k + n + n];
|
||||
}
|
||||
float d =
|
||||
(x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
|
||||
float d2 = min(d, td);
|
||||
// update "point-to-set" distance
|
||||
if (d2 != td)
|
||||
distances[k] = d2;
|
||||
// update the farthest distance at sample step j
|
||||
if (d2 > best) {
|
||||
best = d2;
|
||||
besti = k;
|
||||
}
|
||||
}
|
||||
|
||||
dists[threadIdx.x] = best;
|
||||
dists_i[threadIdx.x] = besti;
|
||||
for (int u = 0; (1 << u) < blockDim.x; u++) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x < (blockDim.x >> (u + 1))) {
|
||||
int i1 = (threadIdx.x * 2) << u;
|
||||
int i2 = (threadIdx.x * 2 + 1) << u;
|
||||
if (dists[i1] < dists[i2]) {
|
||||
dists[i1] = dists[i2];
|
||||
dists_i[i1] = dists_i[i2];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// finish sample step j; old is the sampled index
|
||||
old = dists_i[0];
|
||||
if (threadIdx.x == 0)
|
||||
indices[j] = old;
|
||||
}
|
||||
}
|
||||
|
||||
void furthest_point_sampling(int b, int n, int m, const float *coords,
|
||||
float *distances, int *indices) {
|
||||
furthest_point_sampling_kernel<<<b, 512>>>(b, n, m, coords, distances,
|
||||
indices);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
11
third_party/pvcnn/functional/src/sampling/sampling.cuh
vendored
Normal file
11
third_party/pvcnn/functional/src/sampling/sampling.cuh
vendored
Normal file
|
@ -0,0 +1,11 @@
|
|||
#ifndef _SAMPLING_CUH
|
||||
#define _SAMPLING_CUH
|
||||
|
||||
void gather_features(int b, int c, int n, int m, const float *features,
|
||||
const int *indices, float *out);
|
||||
void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
|
||||
const int *indices, float *grad_x);
|
||||
void furthest_point_sampling(int b, int n, int m, const float *coords,
|
||||
float *distances, int *indices);
|
||||
|
||||
#endif
|
12
third_party/pvcnn/functional/src/sampling/sampling.hpp
vendored
Normal file
12
third_party/pvcnn/functional/src/sampling/sampling.hpp
vendored
Normal file
|
@ -0,0 +1,12 @@
|
|||
#ifndef _SAMPLING_HPP
|
||||
#define _SAMPLING_HPP
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices);
|
||||
at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
|
||||
const int n);
|
||||
at::Tensor furthest_point_sampling_forward(at::Tensor coords,
|
||||
const int num_samples);
|
||||
|
||||
#endif
|
20
third_party/pvcnn/functional/src/utils.hpp
vendored
Normal file
20
third_party/pvcnn/functional/src/utils.hpp
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
#ifndef _UTILS_HPP
|
||||
#define _UTILS_HPP
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
||||
|
||||
#define CHECK_IS_INT(x) \
|
||||
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
|
||||
#x " must be an int tensor")
|
||||
|
||||
#define CHECK_IS_FLOAT(x) \
|
||||
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
|
||||
#x " must be a float tensor")
|
||||
|
||||
#endif
|
76
third_party/pvcnn/functional/src/voxelization/vox.cpp
vendored
Normal file
76
third_party/pvcnn/functional/src/voxelization/vox.cpp
vendored
Normal file
|
@ -0,0 +1,76 @@
|
|||
#include "vox.hpp"
|
||||
#include "vox.cuh"
|
||||
|
||||
#include "../utils.hpp"
|
||||
|
||||
/*
|
||||
Function: average pool voxelization (forward)
|
||||
Args:
|
||||
features: features, FloatTensor[b, c, n]
|
||||
coords : coords of each point, IntTensor[b, 3, n]
|
||||
resolution : voxel resolution
|
||||
Return:
|
||||
out : outputs, FloatTensor[b, c, s], s = r ** 3
|
||||
ind : voxel index of each point, IntTensor[b, n]
|
||||
cnt : #points in each voxel index, IntTensor[b, s]
|
||||
*/
|
||||
std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
|
||||
const at::Tensor coords,
|
||||
const int resolution) {
|
||||
CHECK_CUDA(features);
|
||||
CHECK_CUDA(coords);
|
||||
CHECK_CONTIGUOUS(features);
|
||||
CHECK_CONTIGUOUS(coords);
|
||||
CHECK_IS_FLOAT(features);
|
||||
CHECK_IS_INT(coords);
|
||||
|
||||
int b = features.size(0);
|
||||
int c = features.size(1);
|
||||
int n = features.size(2);
|
||||
int r = resolution;
|
||||
int r2 = r * r;
|
||||
int r3 = r2 * r;
|
||||
at::Tensor ind = torch::zeros(
|
||||
{b, n}, at::device(features.device()).dtype(at::ScalarType::Int));
|
||||
at::Tensor out = torch::zeros(
|
||||
{b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float));
|
||||
at::Tensor cnt = torch::zeros(
|
||||
{b, r3}, at::device(features.device()).dtype(at::ScalarType::Int));
|
||||
avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr<int>(),
|
||||
features.data_ptr<float>(), ind.data_ptr<int>(),
|
||||
cnt.data_ptr<int>(), out.data_ptr<float>());
|
||||
return {out, ind, cnt};
|
||||
}
|
||||
|
||||
/*
|
||||
Function: average pool voxelization (backward)
|
||||
Args:
|
||||
grad_y : grad outputs, FloatTensor[b, c, s]
|
||||
indices: voxel index of each point, IntTensor[b, n]
|
||||
cnt : #points in each voxel index, IntTensor[b, s]
|
||||
Return:
|
||||
grad_x : grad inputs, FloatTensor[b, c, n]
|
||||
*/
|
||||
at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
|
||||
const at::Tensor indices,
|
||||
const at::Tensor cnt) {
|
||||
CHECK_CUDA(grad_y);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CUDA(cnt);
|
||||
CHECK_CONTIGUOUS(grad_y);
|
||||
CHECK_CONTIGUOUS(indices);
|
||||
CHECK_CONTIGUOUS(cnt);
|
||||
CHECK_IS_FLOAT(grad_y);
|
||||
CHECK_IS_INT(indices);
|
||||
CHECK_IS_INT(cnt);
|
||||
|
||||
int b = grad_y.size(0);
|
||||
int c = grad_y.size(1);
|
||||
int s = grad_y.size(2);
|
||||
int n = indices.size(1);
|
||||
at::Tensor grad_x = torch::zeros(
|
||||
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
|
||||
avg_voxelize_grad(b, c, n, s, indices.data_ptr<int>(), cnt.data_ptr<int>(),
|
||||
grad_y.data_ptr<float>(), grad_x.data_ptr<float>());
|
||||
return grad_x;
|
||||
}
|
126
third_party/pvcnn/functional/src/voxelization/vox.cu
vendored
Normal file
126
third_party/pvcnn/functional/src/voxelization/vox.cu
vendored
Normal file
|
@ -0,0 +1,126 @@
|
|||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "../cuda_utils.cuh"
|
||||
|
||||
/*
|
||||
Function: get how many points in each voxel grid
|
||||
Args:
|
||||
b : batch size
|
||||
n : number of points
|
||||
r : voxel resolution
|
||||
r2 : = r * r
|
||||
r3 : s, voxel cube size = r ** 3
|
||||
coords : coords of each point, IntTensor[b, 3, n]
|
||||
ind : voxel index of each point, IntTensor[b, n]
|
||||
cnt : #points in each voxel index, IntTensor[b, s]
|
||||
*/
|
||||
__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3,
|
||||
const int *__restrict__ coords,
|
||||
int *__restrict__ ind, int *cnt) {
|
||||
int batch_index = blockIdx.x;
|
||||
int stride = blockDim.x;
|
||||
int index = threadIdx.x;
|
||||
coords += batch_index * n * 3;
|
||||
ind += batch_index * n;
|
||||
cnt += batch_index * r3;
|
||||
|
||||
for (int i = index; i < n; i += stride) {
|
||||
// if (ind[i] == -1)
|
||||
// continue;
|
||||
ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + n + n];
|
||||
atomicAdd(cnt + ind[i], 1);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Function: average pool voxelization (forward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channels
|
||||
n : number of points
|
||||
s : voxel cube size = voxel resolution ** 3
|
||||
ind : voxel index of each point, IntTensor[b, n]
|
||||
cnt : #points in each voxel index, IntTensor[b, s]
|
||||
feat: features, FloatTensor[b, c, n]
|
||||
out : outputs, FloatTensor[b, c, s]
|
||||
*/
|
||||
__global__ void avg_voxelize_kernel(int b, int c, int n, int s,
|
||||
const int *__restrict__ ind,
|
||||
const int *__restrict__ cnt,
|
||||
const float *__restrict__ feat,
|
||||
float *__restrict__ out) {
|
||||
int batch_index = blockIdx.x;
|
||||
int stride = blockDim.x;
|
||||
int index = threadIdx.x;
|
||||
ind += batch_index * n;
|
||||
feat += batch_index * c * n;
|
||||
out += batch_index * c * s;
|
||||
cnt += batch_index * s;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
int pos = ind[i];
|
||||
// if (pos == -1)
|
||||
// continue;
|
||||
int cur_cnt = cnt[pos];
|
||||
if (cur_cnt > 0) {
|
||||
float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
|
||||
for (int j = 0; j < c; j++) {
|
||||
atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Function: average pool voxelization (backward)
|
||||
Args:
|
||||
b : batch size
|
||||
c : #channels
|
||||
n : number of points
|
||||
r3 : voxel cube size = voxel resolution ** 3
|
||||
ind : voxel index of each point, IntTensor[b, n]
|
||||
cnt : #points in each voxel index, IntTensor[b, s]
|
||||
grad_y : grad outputs, FloatTensor[b, c, s]
|
||||
grad_x : grad inputs, FloatTensor[b, c, n]
|
||||
*/
|
||||
__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3,
|
||||
const int *__restrict__ ind,
|
||||
const int *__restrict__ cnt,
|
||||
const float *__restrict__ grad_y,
|
||||
float *__restrict__ grad_x) {
|
||||
int batch_index = blockIdx.x;
|
||||
int stride = blockDim.x;
|
||||
int index = threadIdx.x;
|
||||
ind += batch_index * n;
|
||||
grad_x += batch_index * c * n;
|
||||
grad_y += batch_index * c * r3;
|
||||
cnt += batch_index * r3;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
int pos = ind[i];
|
||||
// if (pos == -1)
|
||||
// continue;
|
||||
int cur_cnt = cnt[pos];
|
||||
if (cur_cnt > 0) {
|
||||
float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
|
||||
for (int j = 0; j < c; j++) {
|
||||
atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
|
||||
const float *feat, int *ind, int *cnt, float *out) {
|
||||
grid_stats_kernel<<<b, optimal_num_threads(n)>>>(b, n, r, r2, r3, coords, ind,
|
||||
cnt);
|
||||
avg_voxelize_kernel<<<b, optimal_num_threads(n)>>>(b, c, n, r3, ind, cnt,
|
||||
feat, out);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
||||
|
||||
void avg_voxelize_grad(int b, int c, int n, int s, const int *ind,
|
||||
const int *cnt, const float *grad_y, float *grad_x) {
|
||||
avg_voxelize_grad_kernel<<<b, optimal_num_threads(n)>>>(b, c, n, s, ind, cnt,
|
||||
grad_y, grad_x);
|
||||
CUDA_CHECK_ERRORS();
|
||||
}
|
10
third_party/pvcnn/functional/src/voxelization/vox.cuh
vendored
Normal file
10
third_party/pvcnn/functional/src/voxelization/vox.cuh
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
#ifndef _VOX_CUH
|
||||
#define _VOX_CUH
|
||||
|
||||
// CUDA function declarations
|
||||
void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
|
||||
const float *feat, int *ind, int *cnt, float *out);
|
||||
void avg_voxelize_grad(int b, int c, int n, int s, const int *idx,
|
||||
const int *cnt, const float *grad_y, float *grad_x);
|
||||
|
||||
#endif
|
15
third_party/pvcnn/functional/src/voxelization/vox.hpp
vendored
Normal file
15
third_party/pvcnn/functional/src/voxelization/vox.hpp
vendored
Normal file
|
@ -0,0 +1,15 @@
|
|||
#ifndef _VOX_HPP
|
||||
#define _VOX_HPP
|
||||
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
|
||||
const at::Tensor coords,
|
||||
const int resolution);
|
||||
|
||||
at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
|
||||
const at::Tensor indices,
|
||||
const at::Tensor cnt);
|
||||
|
||||
#endif
|
47
third_party/pvcnn/functional/voxelization.py
vendored
Normal file
47
third_party/pvcnn/functional/voxelization.py
vendored
Normal file
|
@ -0,0 +1,47 @@
|
|||
from torch.autograd import Function
|
||||
import torch
|
||||
# from modules.functional.backend import _backend
|
||||
from third_party.pvcnn.functional.backend import _backend
|
||||
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
|
||||
|
||||
__all__ = ['avg_voxelize']
|
||||
|
||||
|
||||
class AvgVoxelization(Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, features, coords, resolution):
|
||||
"""
|
||||
:param ctx:
|
||||
:param features: Features of the point cloud, FloatTensor[B, C, N]
|
||||
:param coords: Voxelized Coordinates of each point, IntTensor[B, 3, N]
|
||||
:param resolution: Voxel resolution
|
||||
:return:
|
||||
Voxelized Features, FloatTensor[B, C, R, R, R]
|
||||
"""
|
||||
features = features.contiguous()
|
||||
coords = coords.int()[:,:3].contiguous()
|
||||
b, c, _ = features.shape
|
||||
out, indices, counts = _backend.avg_voxelize_forward(
|
||||
features, coords, resolution)
|
||||
ctx.save_for_backward(indices, counts)
|
||||
return out.view(b, c, resolution, resolution, resolution)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
:param ctx:
|
||||
:param grad_output: gradient of output, FloatTensor[B, C, R, R, R]
|
||||
:return:
|
||||
gradient of inputs, FloatTensor[B, C, N]
|
||||
"""
|
||||
b, c = grad_output.shape[:2]
|
||||
indices, counts = ctx.saved_tensors
|
||||
grad_features = _backend.avg_voxelize_backward(
|
||||
grad_output.contiguous().view(b, c, -1), indices, counts)
|
||||
return grad_features, None, None
|
||||
|
||||
|
||||
avg_voxelize = AvgVoxelization.apply
|
||||
|
21
third_party/torchdiffeq/LICENSE
vendored
Normal file
21
third_party/torchdiffeq/LICENSE
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2018 Ricky Tian Qi Chen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
1
third_party/torchdiffeq/README.md
vendored
Normal file
1
third_party/torchdiffeq/README.md
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
adapted from `https://github.com/rtqichen/torchdiffeq/tree/master/torchdiffeq`
|
4
third_party/torchdiffeq/torchdiffeq/__init__.py
vendored
Normal file
4
third_party/torchdiffeq/torchdiffeq/__init__.py
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
from ._impl import odeint
|
||||
from ._impl import odeint_adjoint
|
||||
from ._impl import odeint_event
|
||||
__version__ = "0.2.2"
|
2
third_party/torchdiffeq/torchdiffeq/_impl/__init__.py
vendored
Normal file
2
third_party/torchdiffeq/torchdiffeq/_impl/__init__.py
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .odeint import odeint, odeint_event
|
||||
from .adjoint import odeint_adjoint
|
25
third_party/torchdiffeq/torchdiffeq/_impl/adaptive_heun.py
vendored
Normal file
25
third_party/torchdiffeq/torchdiffeq/_impl/adaptive_heun.py
vendored
Normal file
|
@ -0,0 +1,25 @@
|
|||
import torch
|
||||
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
|
||||
|
||||
|
||||
_ADAPTIVE_HEUN_TABLEAU = _ButcherTableau(
|
||||
alpha=torch.tensor([1.], dtype=torch.float64),
|
||||
beta=[
|
||||
torch.tensor([1.], dtype=torch.float64),
|
||||
],
|
||||
c_sol=torch.tensor([0.5, 0.5], dtype=torch.float64),
|
||||
c_error=torch.tensor([
|
||||
0.5,
|
||||
-0.5,
|
||||
], dtype=torch.float64),
|
||||
)
|
||||
|
||||
_AH_C_MID = torch.tensor([
|
||||
0.5, 0.
|
||||
], dtype=torch.float64)
|
||||
|
||||
|
||||
class AdaptiveHeunSolver(RKAdaptiveStepsizeODESolver):
|
||||
order = 2
|
||||
tableau = _ADAPTIVE_HEUN_TABLEAU
|
||||
mid = _AH_C_MID
|
280
third_party/torchdiffeq/torchdiffeq/_impl/adjoint.py
vendored
Normal file
280
third_party/torchdiffeq/torchdiffeq/_impl/adjoint.py
vendored
Normal file
|
@ -0,0 +1,280 @@
|
|||
import warnings
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .odeint import SOLVERS, odeint
|
||||
from .misc import _check_inputs, _flat_to_shape
|
||||
from .misc import _mixed_norm
|
||||
|
||||
|
||||
class OdeintAdjointMethod(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
|
||||
adjoint_options, t_requires_grad, *adjoint_params):
|
||||
|
||||
ctx.shapes = shapes
|
||||
ctx.func = func
|
||||
ctx.adjoint_rtol = adjoint_rtol
|
||||
ctx.adjoint_atol = adjoint_atol
|
||||
ctx.adjoint_method = adjoint_method
|
||||
ctx.adjoint_options = adjoint_options
|
||||
ctx.t_requires_grad = t_requires_grad
|
||||
ctx.event_mode = event_fn is not None
|
||||
|
||||
with torch.no_grad():
|
||||
ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
|
||||
|
||||
if event_fn is None:
|
||||
y = ans
|
||||
else:
|
||||
event_t, y = ans
|
||||
ctx.event_t = event_t
|
||||
|
||||
ctx.save_for_backward(t, y, *adjoint_params)
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_y):
|
||||
with torch.no_grad():
|
||||
func = ctx.func
|
||||
adjoint_rtol = ctx.adjoint_rtol
|
||||
adjoint_atol = ctx.adjoint_atol
|
||||
adjoint_method = ctx.adjoint_method
|
||||
adjoint_options = ctx.adjoint_options
|
||||
t_requires_grad = ctx.t_requires_grad
|
||||
|
||||
t, y, *adjoint_params = ctx.saved_tensors
|
||||
adjoint_params = tuple(adjoint_params)
|
||||
|
||||
# Backprop as if integrating up to event time.
|
||||
# Does NOT backpropagate through the event time.
|
||||
event_mode = ctx.event_mode
|
||||
if event_mode:
|
||||
event_t = ctx.event_t
|
||||
_t = t
|
||||
t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
|
||||
grad_y = grad_y[1]
|
||||
else:
|
||||
grad_y = grad_y[0]
|
||||
|
||||
##################################
|
||||
# Set up initial state #
|
||||
##################################
|
||||
|
||||
# [-1] because y and grad_y are both of shape (len(t), *y0.shape)
|
||||
aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y
|
||||
aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params
|
||||
|
||||
##################################
|
||||
# Set up backward ODE func #
|
||||
##################################
|
||||
|
||||
# TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
|
||||
def augmented_dynamics(t, y_aug):
|
||||
# Dynamics of the original system augmented with
|
||||
# the adjoint wrt y, and an integrator wrt t and args.
|
||||
y = y_aug[1]
|
||||
adj_y = y_aug[2]
|
||||
# ignore gradients wrt time and parameters
|
||||
|
||||
with torch.enable_grad():
|
||||
t_ = t.detach()
|
||||
t = t_.requires_grad_(True)
|
||||
y = y.detach().requires_grad_(True)
|
||||
|
||||
# If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
|
||||
# doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
|
||||
# wrt t here means we won't compute that if we don't need it.
|
||||
func_eval = func(t if t_requires_grad else t_, y)
|
||||
|
||||
# Workaround for PyTorch bug #39784
|
||||
_t = torch.as_strided(t, (), ()) # noqa
|
||||
_y = torch.as_strided(y, (), ()) # noqa
|
||||
_params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa
|
||||
|
||||
vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
|
||||
func_eval, (t, y) + adjoint_params, -adj_y,
|
||||
allow_unused=True, retain_graph=True
|
||||
)
|
||||
|
||||
# autograd.grad returns None if no gradient, set to zero.
|
||||
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
|
||||
vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
|
||||
vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
|
||||
for param, vjp_param in zip(adjoint_params, vjp_params)]
|
||||
|
||||
return (vjp_t, func_eval, vjp_y, *vjp_params)
|
||||
|
||||
##################################
|
||||
# Solve adjoint ODE #
|
||||
##################################
|
||||
|
||||
if t_requires_grad:
|
||||
time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
|
||||
else:
|
||||
time_vjps = None
|
||||
for i in range(len(t) - 1, 0, -1):
|
||||
if t_requires_grad:
|
||||
# Compute the effect of moving the current time measurement point.
|
||||
# We don't compute this unless we need to, to save some computation.
|
||||
func_eval = func(t[i], y[i])
|
||||
dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
|
||||
aug_state[0] -= dLd_cur_t
|
||||
time_vjps[i] = dLd_cur_t
|
||||
|
||||
# Run the augmented system backwards in time.
|
||||
aug_state = odeint(
|
||||
augmented_dynamics, tuple(aug_state),
|
||||
t[i - 1:i + 1].flip(0),
|
||||
rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
|
||||
)
|
||||
aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
|
||||
aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state
|
||||
aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point
|
||||
|
||||
if t_requires_grad:
|
||||
time_vjps[0] = aug_state[0]
|
||||
|
||||
# Only compute gradient wrt initial time when in event handling mode.
|
||||
if event_mode and t_requires_grad:
|
||||
time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])
|
||||
|
||||
adj_y = aug_state[2]
|
||||
adj_params = aug_state[3:]
|
||||
|
||||
return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)
|
||||
|
||||
|
||||
def odeint_adjoint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None,
|
||||
adjoint_rtol=None, adjoint_atol=None, adjoint_method=None, adjoint_options=None, adjoint_params=None):
|
||||
|
||||
# We need this in order to access the variables inside this module,
|
||||
# since we have no other way of getting variables along the execution path.
|
||||
if adjoint_params is None and not isinstance(func, nn.Module):
|
||||
raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they '
|
||||
'can be specified explicitly via the `adjoint_params` argument. If there are no parameters '
|
||||
'then it is allowable to set `adjoint_params=()`.')
|
||||
|
||||
# Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options)
|
||||
if adjoint_rtol is None:
|
||||
adjoint_rtol = rtol
|
||||
if adjoint_atol is None:
|
||||
adjoint_atol = atol
|
||||
if adjoint_method is None:
|
||||
adjoint_method = method
|
||||
|
||||
if adjoint_method != method and options is not None and adjoint_options is None:
|
||||
raise ValueError("If `adjoint_method != method` then we cannot infer `adjoint_options` from `options`. So as "
|
||||
"`options` has been passed then `adjoint_options` must be passed as well.")
|
||||
|
||||
if adjoint_options is None:
|
||||
adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {}
|
||||
else:
|
||||
# Avoid in-place modifying a user-specified dict.
|
||||
adjoint_options = adjoint_options.copy()
|
||||
|
||||
if adjoint_params is None:
|
||||
adjoint_params = tuple(find_parameters(func))
|
||||
else:
|
||||
adjoint_params = tuple(adjoint_params) # in case adjoint_params is a generator.
|
||||
|
||||
# Filter params that don't require gradients.
|
||||
oldlen_ = len(adjoint_params)
|
||||
adjoint_params = tuple(p for p in adjoint_params if p.requires_grad)
|
||||
if len(adjoint_params) != oldlen_:
|
||||
# Some params were excluded.
|
||||
# Issue a warning if a user-specified norm is specified.
|
||||
if 'norm' in adjoint_options and callable(adjoint_options['norm']):
|
||||
warnings.warn("An adjoint parameter was passed without requiring gradient. For efficiency this will be "
|
||||
"excluded from the adjoint pass, and will not appear as a tensor in the adjoint norm.")
|
||||
|
||||
# Convert to flattened state.
|
||||
shapes, func, y0, t, rtol, atol, method, options, event_fn, decreasing_time = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
|
||||
|
||||
# Handle the adjoint norm function.
|
||||
state_norm = options["norm"]
|
||||
handle_adjoint_norm_(adjoint_options, shapes, state_norm)
|
||||
|
||||
ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
|
||||
adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)
|
||||
|
||||
if event_fn is None:
|
||||
solution = ans
|
||||
else:
|
||||
event_t, solution = ans
|
||||
event_t = event_t.to(t)
|
||||
if decreasing_time:
|
||||
event_t = -event_t
|
||||
|
||||
if shapes is not None:
|
||||
solution = _flat_to_shape(solution, (len(t),), shapes)
|
||||
|
||||
if event_fn is None:
|
||||
return solution
|
||||
else:
|
||||
return event_t, solution
|
||||
|
||||
|
||||
def find_parameters(module):
|
||||
|
||||
assert isinstance(module, nn.Module)
|
||||
|
||||
# If called within DataParallel, parameters won't appear in module.parameters().
|
||||
if getattr(module, '_is_replica', False):
|
||||
|
||||
def find_tensor_attributes(module):
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) and v.requires_grad]
|
||||
return tuples
|
||||
|
||||
gen = module._named_members(get_members_fn=find_tensor_attributes)
|
||||
return [param for _, param in gen]
|
||||
else:
|
||||
return list(module.parameters())
|
||||
|
||||
|
||||
def handle_adjoint_norm_(adjoint_options, shapes, state_norm):
|
||||
"""In-place modifies the adjoint options to choose or wrap the norm function."""
|
||||
|
||||
# This is the default adjoint norm on the backward pass: a mixed norm over the tuple of inputs.
|
||||
def default_adjoint_norm(tensor_tuple):
|
||||
t, y, adj_y, *adj_params = tensor_tuple
|
||||
# (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
|
||||
return max(t.abs(), state_norm(y), state_norm(adj_y), _mixed_norm(adj_params))
|
||||
|
||||
if "norm" not in adjoint_options:
|
||||
# `adjoint_options` was not explicitly specified by the user. Use the default norm.
|
||||
adjoint_options["norm"] = default_adjoint_norm
|
||||
else:
|
||||
# `adjoint_options` was explicitly specified by the user...
|
||||
try:
|
||||
adjoint_norm = adjoint_options['norm']
|
||||
except KeyError:
|
||||
# ...but they did not specify the norm argument. Back to plan A: use the default norm.
|
||||
adjoint_options['norm'] = default_adjoint_norm
|
||||
else:
|
||||
# ...and they did specify the norm argument.
|
||||
if adjoint_norm == 'seminorm':
|
||||
# They told us they want to use seminorms. Slight modification to plan A: use the default norm,
|
||||
# but ignore the parameter state
|
||||
def adjoint_seminorm(tensor_tuple):
|
||||
t, y, adj_y, *adj_params = tensor_tuple
|
||||
# (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
|
||||
return max(t.abs(), state_norm(y), state_norm(adj_y))
|
||||
adjoint_options['norm'] = adjoint_seminorm
|
||||
else:
|
||||
# And they're using their own custom norm.
|
||||
if shapes is None:
|
||||
# The state on the forward pass was a tensor, not a tuple. We don't need to do anything, they're
|
||||
# already going to get given the full adjoint state as (t, y, adj_y, adj_params)
|
||||
pass # this branch included for clarity
|
||||
else:
|
||||
# This is the bit that is tuple/tensor abstraction-breaking, because the odeint machinery
|
||||
# doesn't know about the tupled nature of the forward state. We need to tell the user's adjoint
|
||||
# norm about that ourselves.
|
||||
|
||||
def _adjoint_norm(tensor_tuple):
|
||||
t, y, adj_y, *adj_params = tensor_tuple
|
||||
y = _flat_to_shape(y, (), shapes)
|
||||
adj_y = _flat_to_shape(adj_y, (), shapes)
|
||||
return adjoint_norm((t, *y, *adj_y, *adj_params))
|
||||
adjoint_options['norm'] = _adjoint_norm
|
22
third_party/torchdiffeq/torchdiffeq/_impl/bosh3.py
vendored
Normal file
22
third_party/torchdiffeq/torchdiffeq/_impl/bosh3.py
vendored
Normal file
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
|
||||
|
||||
|
||||
_BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau(
|
||||
alpha=torch.tensor([1 / 2, 3 / 4, 1.], dtype=torch.float64),
|
||||
beta=[
|
||||
torch.tensor([1 / 2], dtype=torch.float64),
|
||||
torch.tensor([0., 3 / 4], dtype=torch.float64),
|
||||
torch.tensor([2 / 9, 1 / 3, 4 / 9], dtype=torch.float64)
|
||||
],
|
||||
c_sol=torch.tensor([2 / 9, 1 / 3, 4 / 9, 0.], dtype=torch.float64),
|
||||
c_error=torch.tensor([2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8], dtype=torch.float64),
|
||||
)
|
||||
|
||||
_BS_C_MID = torch.tensor([0., 0.5, 0., 0.], dtype=torch.float64)
|
||||
|
||||
|
||||
class Bosh3Solver(RKAdaptiveStepsizeODESolver):
|
||||
order = 3
|
||||
tableau = _BOGACKI_SHAMPINE_TABLEAU
|
||||
mid = _BS_C_MID
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue