xzeng 2023-01-23 00:14:49 -05:00
parent 1118b9b0c0
commit 1d24a7879d
134 changed files with 18308 additions and 10 deletions

.gitignore vendored Normal file

@@ -0,0 +1,14 @@
__pycache__/
*__pycache__
.idea/
*.pyc
*.m
.ipynb_checkpoints
*swp
*swo
*__pycache__*
models/pvcnn/functional/build/
*.sh
lion_ckpt
data/
datasets/test_data

README.md
@@ -1,20 +1,76 @@
## <p align="center">LION: Latent Point Diffusion Models for 3D Shape Generation<br><br> NeurIPS 2022 </p>
<div align="center">
<a href="https://www.cs.utoronto.ca/~xiaohui/" target="_blank">Xiaohui&nbsp;Zeng</a> &emsp; <b>&middot;</b> &emsp;
<a href="http://latentspace.cc/" target="_blank">Arash&nbsp;Vahdat</a> &emsp; <b>&middot;</b> &emsp;
<a href="https://www.fwilliams.info/" target="_blank">Francis&nbsp;Williams</a> &emsp; <b>&middot;</b> &emsp;
<a href="https://zgojcic.github.io/" target="_blank">Zan&nbsp;Gojcic</a> &emsp; <b>&middot;</b> &emsp;
<a href="https://orlitany.github.io/" target="_blank">Or&nbsp;Litany</a> &emsp; <b>&middot;</b> &emsp;
<a href="https://www.cs.utoronto.ca/~fidler/" target="_blank">Sanja&nbsp;Fidler</a> &emsp; <b>&middot;</b> &emsp;
<a href="https://karstenkreis.github.io/" target="_blank">Karsten&nbsp;Kreis</a>
<br> <br>
<a href="https://arxiv.org/abs/2210.06978" target="_blank">Paper</a> &emsp;
<a href="https://nv-tlabs.github.io/LION" target="_blank">Project&nbsp;Page</a>
</div>
<br><br>
<p align="center">
<img width="750" alt="Animation" src="assets/animation.gif"/>
</p>
## Install
* Dependencies:
    * CUDA 11.6
* Set up the environment by installing from the conda file:
```
conda env create --name lion_env --file=env.yaml
conda activate lion_env
# Install some other packages
pip install git+https://github.com/openai/CLIP.git
# build some packages first (optional)
python build_pkg.py
```
Tested with conda version 22.9.0
## Demo
Run `python demo.py`; it will load the released text2shape model from Hugging Face and generate a chair point cloud.
## Released checkpoint and samples
* will be released soon
* put the downloaded files under `./lion_ckpt/`
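Based on the paths used by `demo.py` and the evaluation command below, the expected layout is roughly as follows (the exact contents of the release may differ):
```
lion_ckpt/
├── text2shape/chair/
│   ├── cfg.yml
│   └── checkpoints/model.pt
└── unconditional/airplane/
    └── checkpoints/model.pt
```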
## Training
### data
* ShapeNet can be downloaded [here](https://github.com/stevenygd/PointFlow#dataset).
* Put the downloaded data as `./data/ShapeNetCore.v2.PC15k` *or* edit the `pointflow` entry in `./datasets/data_path.py` for the ShapeNet dataset path.
### train VAE
* run `bash ./script/train_vae.sh $NGPU` (the released checkpoint is trained with `NGPU=4`)
### train diffusion prior
* requires the VAE checkpoint
* run `bash ./script/train_prior.sh $NGPU` (the released checkpoint is trained with `NGPU=8` on 2 nodes)
### evaluate a trained prior
* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`
* download the released checkpoint from above
```
checkpoint="./lion_ckpt/unconditional/airplane/checkpoints/model.pt"
bash ./script/eval.sh $checkpoint # will take 1-2 hours
```
## Evaluate the samples with the 1-NNA metrics
* download the test data from [here](https://drive.google.com/file/d/1uEp0o6UpRqfYwvRXQGZ5ZgT1IYBQvUSV/view?usp=share_link), unzip and put it as `./datasets/test_data/`
* run `python ./script/compute_score.py`
## Citation
```
@inproceedings{zeng2022lion,
title={LION: Latent Point Diffusion Models for 3D Shape Generation},
author={Xiaohui Zeng and Arash Vahdat and Francis Williams and Zan Gojcic and Or Litany and Sanja Fidler and Karsten Kreis},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
```

build_pkg.py Normal file

@@ -0,0 +1,3 @@
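# Importing these modules triggers compilation of the repo's custom extensions
# (e.g. the PVCNN ops built under models/pvcnn/functional/build/), so running
# this file once pre-builds them; this matches the README's optional
# `python build_pkg.py` step.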
import clip
from models import pvcnn2
from utils import eval_helper

datasets/data_path.py Normal file

@@ -0,0 +1,38 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import os
def get_path(dataname=None):
dataset_path = {}
dataset_path['pointflow'] = [
'./data/ShapeNetCore.v2.PC15k/'
]
if dataname is None:
return dataset_path
else:
assert(
dataname in dataset_path), f'not found {dataname}, only: {list(dataset_path.keys())}'
for p in dataset_path[dataname]:
print(f'searching: {dataname}, get: {p}')
if os.path.exists(p):
return p
raise ValueError(
f'no path found for {dataname}; please double check: {dataset_path[dataname]}, or edit datasets/data_path.py')
def get_cache_path():
cache_list = ['/workspace/data_cache_local/data_stat/',
'/workspace/data_cache/data_stat/']
for p in cache_list:
if os.path.exists(p):
return p
raise ValueError(
f'no cache path found; please double check: {cache_list}, or edit datasets/data_path.py')
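# Usage sketch (assumes the repo root as working directory and the data placed
# as described in the README):
#   from datasets.data_path import get_path
#   root = get_path('pointflow')  # -> './data/ShapeNetCore.v2.PC15k/' if it exists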

datasets/pointflow_datasets.py Normal file

@@ -0,0 +1,404 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" copied and modified from https://github.com/stevenygd/PointFlow/blob/master/datasets.py """
import os
import open3d as o3d
import time
import torch
import numpy as np
from loguru import logger
from torch.utils.data import Dataset
from torch.utils import data
import random
import tqdm
from datasets.data_path import get_path
OVERFIT = 0
# taken from https://github.com/optas/latent_3d_points/blob/
# 8e8f29f8124ed5fc59439e8551ba7ef7567c9a37/src/in_out.py
synsetid_to_cate = {
'02691156': 'airplane',
'02773838': 'bag',
'02801938': 'basket',
'02808440': 'bathtub',
'02818832': 'bed',
'02828884': 'bench',
'02876657': 'bottle',
'02880940': 'bowl',
'02924116': 'bus',
'02933112': 'cabinet',
'02747177': 'can',
'02942699': 'camera',
'02954340': 'cap',
'02958343': 'car',
'03001627': 'chair',
'03046257': 'clock',
'03207941': 'dishwasher',
'03211117': 'monitor',
'04379243': 'table',
'04401088': 'telephone',
'02946921': 'tin_can',
'04460130': 'tower',
'04468005': 'train',
'03085013': 'keyboard',
'03261776': 'earphone',
'03325088': 'faucet',
'03337140': 'file',
'03467517': 'guitar',
'03513137': 'helmet',
'03593526': 'jar',
'03624134': 'knife',
'03636649': 'lamp',
'03642806': 'laptop',
'03691459': 'speaker',
'03710193': 'mailbox',
'03759954': 'microphone',
'03761084': 'microwave',
'03790512': 'motorcycle',
'03797390': 'mug',
'03928116': 'piano',
'03938244': 'pillow',
'03948459': 'pistol',
'03991062': 'pot',
'04004475': 'printer',
'04074963': 'remote_control',
'04090263': 'rifle',
'04099429': 'rocket',
'04225987': 'skateboard',
'04256520': 'sofa',
'04330267': 'stove',
'04530566': 'vessel',
'04554684': 'washer',
'02992529': 'cellphone',
'02843684': 'birdhouse',
'02871439': 'bookshelf',
# '02858304': 'boat', no boat in our dataset, merged into vessels
# '02834778': 'bicycle', not in our taxonomy
}
cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()}
class ShapeNet15kPointClouds(Dataset):
def __init__(self,
categories=['airplane'],
tr_sample_size=10000,
te_sample_size=10000,
split='train',
scale=1.,
normalize_per_shape=False,
normalize_shape_box=False,
random_subsample=False,
sample_with_replacement=1,
normalize_std_per_axis=False,
normalize_global=False,
recenter_per_shape=False,
all_points_mean=None,
all_points_std=None,
input_dim=3,
):
self.normalize_shape_box = normalize_shape_box
root_dir = get_path('pointflow')
self.root_dir = root_dir
logger.info('[DATA] cat: {}, split: {}, full path: {}; norm global={}, norm-box={}',
categories, split, self.root_dir, normalize_global, normalize_shape_box)
self.split = split
assert self.split in ['train', 'test', 'val']
self.tr_sample_size = tr_sample_size
self.te_sample_size = te_sample_size
if type(categories) is str:
categories = [categories]
self.cates = categories
if 'all' in categories:
self.synset_ids = list(cate_to_synsetid.values())
else:
self.synset_ids = [cate_to_synsetid[c] for c in self.cates]
subdirs = self.synset_ids
# assert 'v2' in root_dir, "Only supporting v2 right now."
self.gravity_axis = 1
self.display_axis_order = [0, 2, 1]
self.root_dir = root_dir
self.split = split
self.in_tr_sample_size = tr_sample_size
self.in_te_sample_size = te_sample_size
self.subdirs = subdirs
self.scale = scale
self.random_subsample = random_subsample
self.sample_with_replacement = sample_with_replacement
self.input_dim = input_dim
self.all_cate_mids = []
self.cate_idx_lst = []
self.all_points = []
tic = time.time()
for cate_idx, subd in enumerate(self.subdirs):
# NOTE: [subd] here is synset id
sub_path = os.path.join(root_dir, subd, self.split)
if not os.path.isdir(sub_path):
print("Directory missing: %s" % sub_path)
raise ValueError('check the data path')
if True:
all_mids = []
assert(os.path.exists(sub_path)), f'path missing: {sub_path}'
for x in os.listdir(sub_path):
if not x.endswith('.npy'):
continue
all_mids.append(os.path.join(self.split, x[:-len('.npy')]))
logger.info('[DATA] number of files [{}] under: {} ',
len(os.listdir(sub_path)), sub_path)
# NOTE: [mid] contains the split: i.e. "train/<mid>"
# or "val/<mid>" or "test/<mid>"
all_mids = sorted(all_mids)
for mid in all_mids:
# obj_fname = os.path.join(sub_path, x)
obj_fname = os.path.join(root_dir, subd, mid + ".npy")
point_cloud = np.load(obj_fname) # (15k, 3)
self.all_points.append(point_cloud[np.newaxis, ...])
self.cate_idx_lst.append(cate_idx)
self.all_cate_mids.append((subd, mid))
logger.info('[DATA] Load data time: {:.1f}s | dir: {} | '
'sample_with_replacement: {}; num points: {}', time.time() - tic, self.subdirs,
self.sample_with_replacement, len(self.all_points))
# Shuffle the index deterministically (based on the number of examples)
self.shuffle_idx = list(range(len(self.all_points)))
random.Random(38383).shuffle(self.shuffle_idx)
self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
self.all_points = [self.all_points[i] for i in self.shuffle_idx]
self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]
# Normalization
self.all_points = np.concatenate(self.all_points) # (N, 15000, 3)
self.normalize_per_shape = normalize_per_shape
self.normalize_std_per_axis = normalize_std_per_axis
self.recenter_per_shape = recenter_per_shape
if self.normalize_shape_box: # per shape normalization
B, N = self.all_points.shape[:2]
self.all_points_mean = ( # B,1,3
(np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) +
(np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)) / 2
self.all_points_std = np.amax( # B,1,1
((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) -
(np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)),
axis=-1).reshape(B, 1, 1) / 2
elif self.normalize_per_shape: # per shape normalization
B, N = self.all_points.shape[:2]
self.all_points_mean = self.all_points.mean(axis=1).reshape(
B, 1, input_dim)
logger.info('all_points shape: {}. mean over axis=1',
self.all_points.shape)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(
B, N, -1).std(axis=1).reshape(B, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(
B, -1).std(axis=1).reshape(B, 1, 1)
elif all_points_mean is not None and all_points_std is not None and not self.recenter_per_shape:
# using loaded dataset stats
self.all_points_mean = all_points_mean
self.all_points_std = all_points_std
elif self.recenter_per_shape: # per shape center
# TODO: bounding box scale at the large dim and center
B, N = self.all_points.shape[:2]
self.all_points_mean = (
(np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) +
(np.amin(self.all_points, axis=1)).reshape(B, 1,
input_dim)) / 2
self.all_points_std = np.amax(
((np.amax(self.all_points, axis=1)).reshape(B, 1, input_dim) -
(np.amin(self.all_points, axis=1)).reshape(B, 1, input_dim)),
axis=-1).reshape(B, 1, 1) / 2
# else: # normalize across the dataset
elif normalize_global: # normalize across the dataset
self.all_points_mean = self.all_points.reshape(
-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
if normalize_std_per_axis:
self.all_points_std = self.all_points.reshape(
-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
else:
self.all_points_std = self.all_points.reshape(-1).std(
axis=0).reshape(1, 1, 1)
logger.info('[DATA] normalize_global: mean={}, std={}',
self.all_points_mean.reshape(-1),
self.all_points_std.reshape(-1))
else:
raise NotImplementedError('No Normalization')
self.all_points = (self.all_points - self.all_points_mean) / \
self.all_points_std
logger.info('[DATA] shape={}, all_points_mean:={}, std={}, max={:.3f}, min={:.3f}; num-pts={}',
self.all_points.shape,
self.all_points_mean.shape, self.all_points_std.shape,
self.all_points.max(), self.all_points.min(), tr_sample_size)
if OVERFIT:
self.all_points = self.all_points[:40]
# TODO: why do we need this??
self.train_points = self.all_points[:, :min(
10000, self.all_points.shape[1])]
self.tr_sample_size = min(10000, tr_sample_size)
self.te_sample_size = min(5000, te_sample_size)
assert self.scale == 1, "Scale (!= 1) is deprecated"
# Default display axis order
self.display_axis_order = [0, 1, 2]
def get_pc_stats(self, idx):
if self.recenter_per_shape:
m = self.all_points_mean[idx].reshape(1, self.input_dim)
s = self.all_points_std[idx].reshape(1, -1)
return m, s
if self.normalize_per_shape or self.normalize_shape_box:
m = self.all_points_mean[idx].reshape(1, self.input_dim)
s = self.all_points_std[idx].reshape(1, -1)
return m, s
return self.all_points_mean.reshape(1, -1), \
self.all_points_std.reshape(1, -1)
def renormalize(self, mean, std):
self.all_points = self.all_points * self.all_points_std + \
self.all_points_mean
self.all_points_mean = mean
self.all_points_std = std
self.all_points = (self.all_points - self.all_points_mean) / \
self.all_points_std
self.train_points = self.all_points[:, :min(
10000, self.all_points.shape[1])]
## self.test_points = self.all_points[:, 10000:]
def __len__(self):
return len(self.train_points)
def __getitem__(self, idx):
output = {}
tr_out = self.train_points[idx]
if self.random_subsample and self.sample_with_replacement:
tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
elif self.random_subsample and not self.sample_with_replacement:
tr_idxs = np.random.permutation(
np.arange(tr_out.shape[0]))[:self.tr_sample_size]
else:
tr_idxs = np.arange(self.tr_sample_size)
tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
m, s = self.get_pc_stats(idx)
cate_idx = self.cate_idx_lst[idx]
sid, mid = self.all_cate_mids[idx]
input_pts = tr_out
output.update(
{
'idx': idx,
'select_idx': tr_idxs,
'tr_points': tr_out,
'input_pts': input_pts,
'mean': m,
'std': s,
'cate_idx': cate_idx,
'sid': sid,
'mid': mid,
'display_axis_order': self.display_axis_order
})
return output
def init_np_seed(worker_id):
seed = torch.initial_seed()
np.random.seed(seed % 4294967296)
def get_datasets(cfg, args):
"""
cfg: config.data sub part
"""
if OVERFIT:
random_subsample = 0
else:
random_subsample = cfg.random_subsample
logger.info(f'get_datasets: tr_sample_size={cfg.tr_max_sample_points}, '
f' te_sample_size={cfg.te_max_sample_points}; '
f' random_subsample={random_subsample}'
f' normalize_global={cfg.normalize_global}'
f' normalize_std_per_axis={cfg.normalize_std_per_axis}'
f' normalize_per_shape={cfg.normalize_per_shape}'
f' recenter_per_shape={cfg.recenter_per_shape}'
)
kwargs = {}
tr_dataset = ShapeNet15kPointClouds(
categories=cfg.cates,
split='train',
tr_sample_size=cfg.tr_max_sample_points,
te_sample_size=cfg.te_max_sample_points,
sample_with_replacement=cfg.sample_with_replacement,
scale=cfg.dataset_scale, # root_dir=cfg.data_dir,
normalize_shape_box=cfg.normalize_shape_box,
normalize_per_shape=cfg.normalize_per_shape,
normalize_std_per_axis=cfg.normalize_std_per_axis,
normalize_global=cfg.normalize_global,
recenter_per_shape=cfg.recenter_per_shape,
random_subsample=random_subsample,
**kwargs)
eval_split = getattr(args, "eval_split", "val")
# te_dataset has random_subsample as False, therefore not using sample_with_replacement
te_dataset = ShapeNet15kPointClouds(
categories=cfg.cates,
split=eval_split,
tr_sample_size=cfg.tr_max_sample_points,
te_sample_size=cfg.te_max_sample_points,
scale=cfg.dataset_scale, # root_dir=cfg.data_dir,
normalize_shape_box=cfg.normalize_shape_box,
normalize_per_shape=cfg.normalize_per_shape,
normalize_std_per_axis=cfg.normalize_std_per_axis,
normalize_global=cfg.normalize_global,
recenter_per_shape=cfg.recenter_per_shape,
all_points_mean=tr_dataset.all_points_mean,
all_points_std=tr_dataset.all_points_std,
)
return tr_dataset, te_dataset
def get_data_loaders(cfg, args):
tr_dataset, te_dataset = get_datasets(cfg, args)
kwargs = {}
if args.distributed:
kwargs['sampler'] = data.distributed.DistributedSampler(
tr_dataset, shuffle=True)
else:
kwargs['shuffle'] = True
if args.eval_trainnll:
kwargs['shuffle'] = False
train_loader = data.DataLoader(dataset=tr_dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
drop_last=cfg.train_drop_last == 1,
pin_memory=False, **kwargs)
test_loader = data.DataLoader(dataset=te_dataset,
batch_size=cfg.batch_size_test,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=False,
drop_last=False,
)
logger.info(
f'[Batch Size] train={cfg.batch_size}, test={cfg.batch_size_test}; drop-last={cfg.train_drop_last}')
loaders = {
"test_loader": test_loader,
'train_loader': train_loader,
}
return loaders
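# Usage sketch (assumed wiring): cfg is the `data` sub-config from
# default_config.py, and args must provide .distributed and .eval_trainnll
# (plus optionally .eval_split):
#   from default_config import cfg
#   loaders = get_data_loaders(cfg.data, args)
#   batch = next(iter(loaders['train_loader']))
#   batch['input_pts']  # (B, tr_max_sample_points, 3) normalized points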

default_config.py Normal file

@@ -0,0 +1,450 @@
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
# ---------------------------------------------------------------
from third_party.yacs_config import CfgNode as CN
cfg = CN()
cfg.dpm_ckpt = ''
cfg.clipforge = CN()
cfg.clipforge.clip_model = "ViT-B/32"
cfg.clipforge.enable = 0
cfg.clipforge.feat_dim = 512
cfg.eval_trainnll = 0
cfg.exp_name = ''
cfg.cmt = ''
cfg.hash = ''
cfg.ngpu = 1
cfg.snapshot_min = 30 # snapshot every 30 min
cfg.bash_name = ''
cfg.set_detect_anomaly = 0
cfg.weight_recont = 1.0
# vae ckpt
# lns
cfg.use_checkpoint = 0
cfg.num_val_samples = 16 # 24 #12
# config for pointtransformer
cfg.eval = CN()
cfg.eval.need_denoise = 0
cfg.eval.load_other_vae_ckpt = 0
cfg.register_deprecated_key('eval.other_vae_ckpt_path')
cfg.vis_latent_point = 0
cfg.latent_pts = CN()
#cfg.latent_pts.class_embed_layer = ''
cfg.register_deprecated_key('latent_pts.class_embed_layer')
cfg.latent_pts.style_dim = 128 # dim of global style latent variable
cfg.register_deprecated_key('latent_pts.perturb_input')
cfg.register_deprecated_key('latent_pts.perturb_input_scale')
cfg.register_deprecated_key('latent_pts.outlier_input')
# scale of init weights for the mlp in adaGN layer
cfg.latent_pts.ada_mlp_init_scale = 1.0
# models.latent_points_ada.StyleMLP' # style mlp layers
cfg.latent_pts.style_mlp = ''
cfg.latent_pts.pts_sigma_offset = 0.0
cfg.latent_pts.skip_weight = 0.1
cfg.latent_pts.encoder_layer_out_dim = 32
cfg.latent_pts.decoder_layer_out_dim = 32
cfg.register_deprecated_key('latent_pts.encoder_nneighbor')
cfg.register_deprecated_key('latent_pts.decoder_nneighbor')
cfg.latent_pts.style_prior = 'models.score_sde.resnet.PriorSEDrop'
cfg.latent_pts.mask_out_extra_latent = 0 # use only latent coordinates
# latent coordinates directly same as input (not using the decoder and encoder)
cfg.register_deprecated_key('latent_pts.latent_as_pts')
cfg.latent_pts.normalization = 'bn' # BatchNorm or LayerNorm
cfg.latent_pts.pvd_mse_loss = 0
cfg.latent_pts.hid = 64
cfg.register_deprecated_key('latent_pts.knn')
cfg.register_deprecated_key('latent_pts.n5layer')
cfg.register_deprecated_key('latent_pts.dgcnn_last_hid')
cfg.latent_pts.latent_dim_ext = [64] # the global latent dim
cfg.latent_pts.weight_kl_pt = 1.0 # kl ratio of the pts
cfg.latent_pts.weight_kl_feat = 1.0 # kl ratio of the latent feat
cfg.latent_pts.weight_kl_glb = 1.0 # kl ratio of the latent feat
# kl ratio of the latent feat
cfg.latent_pts.style_encoder = 'models.shapelatent_modules.PointNetPlusEncoder'
cfg.latent_pts.use_linear_for_adagn = 0
# cfg.latent_pts.weight_kl_glb = 1.0 # kl ratio of the global latent
# shapelatent:
cfg.has_shapelatent = 1
cfg.shapelatent = CN()
cfg.shapelatent.local_emb_agg = 'mean'
cfg.shapelatent.freeze_vae = 0 # learn vae
cfg.shapelatent.eps_z_global_only = 1
cfg.shapelatent.model = 'flow'
cfg.shapelatent.residual = 1
cfg.shapelatent.encoder_type = 'pointnet'
cfg.shapelatent.prior_type = 'flow'
cfg.shapelatent.decoder_type = 'PointwiseNet'
cfg.shapelatent.loss0_weight = 1.0
cfg.shapelatent.latent_dim = 256
cfg.shapelatent.kl_weight = 1e-3
cfg.shapelatent.decoder_num_points = -1
# offset the sigma towards zero for better init, will use the log_sigma - offset value, better to be positive s.t. - offset < 0 since we'd like to push it towards 0; exp(-0.1)=0.9, exp(-0.8)=0.44, exp(-1)=0.3, exp(-10)=4e-5
cfg.shapelatent.log_sigma_offset = 0.0
cfg.sde = CN()
cfg.sde.ode_sample = 0 #1
# train the prior or not, default is 1, only when we do voxel2pts, will freeze prior
cfg.sde.train_dae = 1
cfg.sde.init_t = 1.0 # start from time = 1.0
cfg.sde.nhead = 4 # number of heads in transformer: multi-head attention layer
cfg.sde.local_prior = 'same_as_global' # architecture for local prior
cfg.sde.drop_inactive_var = 0
cfg.sde.learn_mixing_logit = 1 # freeze it
cfg.sde.regularize_mlogit_margin = 0.0
cfg.sde.share_mlogit = 0 # use same mlogit for all latent variables
cfg.sde.hypara_mixing_logit = 0 # set as hyper-parameter and freeze it?
cfg.sde.bound_mlogit = 0 # clamp or not
cfg.sde.bound_mlogit_value = -5.42 # clamp the max value
cfg.sde.regularize_mlogit = 0 # set the sum of sigmoid(mlogit) as one loss
cfg.sde.attn_mhead = 0 # use multi-head attention in prior model
cfg.sde.attn_mhead_local = -1 # use multi-head attention in prior model
cfg.sde.pos_embed = 'none'
cfg.sde.hier_prior = 0
cfg.sde.is_continues = 0
cfg.sde.time_emb_scales = 1.0 # -> 1k?
cfg.sde.time_eps = 1e-2
cfg.sde.ode_eps = 1e-5 # cut off for ode sampling
cfg.sde.sde_type = 'vpsde' # vada
cfg.sde.sigma2_0 = 0.0
cfg.sde.sigma2_max = 0.99
cfg.sde.sigma2_min = 1e-4
cfg.sde.beta_start = 0.1 # 1e-4 * 1e3
cfg.sde.beta_end = 20.0 # 1e-2 * 1e3
# sampling, always iw # ll: small times; 'll_uniform' # -> ll_iw
cfg.sde.iw_sample_p = 'll_iw'
# drop_all_iw / drop_sigma2t_iw
cfg.sde.iw_subvp_like_vp_sde = False
cfg.sde.prior_model = 'models.latent_points_ada_localprior.PVCNN2Prior'
# -- to train diffusion in latent space -- #
cfg.sde.update_q_ema = False
cfg.sde.iw_sample_q = 'reweight_p_samples'
# ll_iw / reweight_p_samples
cfg.sde.kl_anneal_portion_vada = 0.1
cfg.sde.kl_const_portion_vada = 0.0
cfg.sde.kl_const_coeff_vada = 0.7
cfg.sde.kl_balance_vada = False
cfg.sde.grad_clip_max_norm = 0.0
cfg.sde.cont_kl_anneal = True
# False
cfg.sde.mixing_logit_init = -6
cfg.sde.weight_decay_norm_vae = 0.0 #1e-2
cfg.sde.weight_decay_norm_dae = 0.0 #1e-2
# -> 0, for sn calculator
cfg.sde.train_vae = True
cfg.sde.jac_reg_coeff = 0
cfg.sde.jac_reg_freq = 1
cfg.sde.kin_reg_coeff = 0
cfg.sde.learning_rate_mlogit = -1.0
cfg.sde.learning_rate_dae_local = 3e-4
cfg.sde.learning_rate_min_dae_local = 3e-4
cfg.sde.learning_rate_dae = 3e-4
cfg.sde.learning_rate_min_dae = 3e-4
cfg.sde.learning_rate_min_vae = 1e-5
cfg.sde.learning_rate_vae = 1e-4
cfg.sde.epochs = 800
cfg.sde.warmup_epochs = 20
cfg.sde.weight_decay = 3e-4
cfg.sde.use_adamax = False
cfg.sde.use_adam = True # False
cfg.sde.mixed_prediction = False # True
cfg.sde.vae_checkpoint = ''
cfg.sde.dae_checkpoint = ''
# will be used to multiply with the t value, if ode solver, use 1k, if discrete solver, use 1.0
cfg.sde.embedding_scale = 1.0 # 1000.0
cfg.sde.embedding_type = 'positional'
cfg.sde.train_ode_solver_tol = 1e-5
cfg.sde.num_scales_dae = 2
cfg.sde.autocast_train = False
cfg.sde.diffusion_steps = 1000
cfg.sde.embedding_dim = 128
cfg.sde.num_channels_dae = 256
cfg.sde.num_cell_per_scale_dae = 8
cfg.sde.num_cell_per_scale_dae_local = 0
cfg.sde.dropout = 0.2
cfg.sde.num_preprocess_blocks = 2
cfg.sde.num_latent_scales = 1
cfg.sde.fir = False
cfg.sde.progressive = 'none'
cfg.sde.progressive_input = 'none'
cfg.sde.progressive_combine = 'sum'
cfg.sde.dataset = 'shape'
cfg.sde.denoising_stddevs = 'beta'
cfg.sde.ema_decay = 0.9999
# cfg.sde.is_train_vae=True
cfg.register_deprecated_key("sde.is_train_vae")
cfg.sde.kl_max_coeff_vada = 1.0
# conditional prior input
cfg.sde.condition_add = 1
cfg.sde.condition_cat = 0
cfg.sde.global_prior_ckpt = '' # checkpoint for global prior component
cfg.sde.pool_feat_cat = 0 # the local prior aggregate the feat as extra input channels
# hyperparameter of ddim sampling
cfg.sde.ddim_skip_type = 'uniform'
cfg.sde.ddim_kappa = 1.0 # 1.0: fully ddpm sampling; 0: ode style sampling
cfg.ddpm = CN()
cfg.ddpm.use_p2_weight = 0
cfg.ddpm.p2_k = 1.0
cfg.ddpm.p2_gamma = 1.0
cfg.ddpm.use_new_timeemb = 0
cfg.ddpm.input_dim = 3
cfg.ddpm.dropout = 0.1
cfg.ddpm.num_layers_classifier = 3
cfg.ddpm.use_bn = True
cfg.ddpm.add_point_feat = True
cfg.ddpm.use_gn = False
cfg.ddpm.time_dim = 64
cfg.ddpm.ema = 1
cfg.ddpm.with_se = 0
cfg.ddpm.use_global_attn = 0
cfg.ddpm.num_steps = 1000
cfg.ddpm.beta_1 = 1e-4
cfg.ddpm.beta_T = 2e-2
# ['linear', 'customer'] 'customer' for airplane in PVD
cfg.ddpm.sched_mode = 'linear'
cfg.ddpm.model_var_type = 'fixedlarge'
# define architecture:
cfg.register_deprecated_key("ddpm.pointnet_plus")
cfg.register_deprecated_key("ddpm.pointnet_pp")
cfg.register_deprecated_key("ddpm.pointnet_luo")
# end define architecture
#cfg.ddpm.use_pvc = 1
cfg.register_deprecated_key("ddpm.use_pvc")
cfg.ddpm.clip_denoised = 0
cfg.ddpm.model_mean_type = 'eps'
cfg.ddpm.loss_type = 'mse'
cfg.ddpm.loss_type_0 = ''
cfg.ddpm.loss_weight_emd = 0.02
cfg.ddpm.loss_weight_cdnorm = 1.0
cfg.ddpm.attn = [0, 1, 0, 0]
cfg.ddpm.ncenter = [1024, 256, 64, 16]
#cfg.ddpm.pvc = CN()
#cfg.ddpm.pvc.use_small_model = 0
#cfg.ddpm.pvc.mlp_after_pvc = 0
cfg.register_deprecated_key("ddpm.pvc")
cfg.register_deprecated_key("ddpm.pvc.use_small_model")
cfg.register_deprecated_key("ddpm.pvc.mlp_after_pvc")
cfg.ddpm.ddim_step = 200
cfg.data = CN()
cfg.data.nclass = 55
cfg.data.cond_on_cat = 0
cfg.data.cond_on_voxel = 0
cfg.data.eval_test_split = 0 # eval loader will be using test split
cfg.data.voxel_size = 0.1 # size of voxel for voxel_datasets.py
cfg.data.noise_std = 0.1 # std for the noise added to the input data
cfg.data.noise_type = 'normal' # std for the noise added to the input data
cfg.data.noise_std_min = -1.0 # for range of noise std
cfg.data.clip_forge_enable = 0
cfg.data.clip_model = 'ViT-B/32'
cfg.data.type = "datasets.pointflow_datasets"
# datasets/neuralspline_datasets datasets/shape_curvature
cfg.data.dataset_type = "shapenet15k"
cfg.data.num_workers = 12 # 8
cfg.data.train_drop_last = 1 # drop_last for train data loader
cfg.data.cates = 'chair' # data category
cfg.data.tr_max_sample_points = 2048
cfg.data.te_max_sample_points = 2048
cfg.data.data_dir = "data/ShapeNetCore.v2.PC15k" # deprecated
cfg.data.batch_size = 12
cfg.data.batch_size_test = 10
cfg.data.dataset_scale = 1
# -- the following option in terms of normalization should turn into string -- #
cfg.data.normalize_per_shape = False
cfg.data.normalize_shape_box = False
cfg.data.normalize_global = False
cfg.data.normalize_std_per_axis = False
cfg.data.normalize_range = False # not used
cfg.data.recenter_per_shape = True
# -- for the normal prediction model, used in folder_datasets
cfg.register_deprecated_key('data.load_point_stat')
cfg.register_deprecated_key('data.is_load_pointflow2NS')
cfg.register_deprecated_key('data.data_path')
#
cfg.data.sample_with_replacement = 1
# randomly subsample data.tr_max_sample_points / data.te_max_sample_points points out of the first 15k points
cfg.data.random_subsample = 1
# the data dim, used in dataset worker, if -1, it will be the same as ddpm.input_dim
cfg.data.input_dim = -1
cfg.data.is_encode_whole_dataset_trainer = 0
cfg.register_deprecated_key('data.augment')
cfg.register_deprecated_key('data.aug_translate')
cfg.register_deprecated_key('data.aug_scale')
cfg.register_deprecated_key('data.sub_train_set')
cfg.test_size = 660
cfg.viz = CN()
cfg.viz.log_freq = 10
cfg.viz.viz_freq = 400
cfg.viz.save_freq = 200
cfg.viz.val_freq = -1
cfg.viz.viz_order = [2, 0, 1]
cfg.viz.vis_sample_ddim_step = 0
cfg.trainer = CN()
# when loss 1 is weighted, also weight the kl terms
cfg.trainer.apply_loss_weight_1_kl = 0
cfg.trainer.kl_free = [0, 0] # the value for the threshold
# not back ward kl loss if KL value is smaller than the threshold
cfg.trainer.use_kl_free = 0
cfg.trainer.type = "trainers.ddpm_trainer" # it means dist trainer
cfg.trainer.epochs = 10000
cfg.trainer.warmup_epochs = 0
cfg.trainer.seed = 1
cfg.trainer.use_grad_scalar = 0
cfg.trainer.opt = CN()
cfg.trainer.opt.type = 'adam'
cfg.trainer.opt.lr = 1e-4 # use bs*1e-5/8
cfg.trainer.opt.lr_min = 1e-4 # use bs*1e-5/8
# lr start to anneal after ratio of epochs; used in cosine and lambda lr scheduler
cfg.trainer.opt.start_ratio = 0.6
cfg.trainer.opt.beta1 = 0.9
cfg.trainer.opt.beta2 = 0.999
cfg.trainer.opt.momentum = 0.9 # for SGD
cfg.trainer.opt.weight_decay = 0.
cfg.trainer.opt.ema_decay = 0.9999
cfg.trainer.opt.grad_clip = -1.
cfg.trainer.opt.scheduler = ''
cfg.trainer.opt.step_decay = 0.998
cfg.trainer.opt.vae_lr_warmup_epochs = 0
cfg.trainer.anneal_kl = 0
cfg.trainer.kl_balance = 0
cfg.trainer.rec_balance = 0
cfg.trainer.loss1_weight_anneal_v = 'quad'
cfg.trainer.kl_ratio = [1.0, 1.0]
cfg.trainer.kl_ratio_apply = 0 # apply the fixed kl ratio in the kl_ratio list
# using spectral norm regularization on vae training or not (used in hvae_trainer)
cfg.trainer.sn_reg_vae = 0
cfg.trainer.sn_reg_vae_weight = 0.0 # loss weight for the sn regularization
# [start] set in runtime
cfg.log_name = ''
cfg.save_dir = ''
cfg.log_dir = ''
cfg.comet_key = ''
# [end]
cfg.voxel2pts = CN()
cfg.voxel2pts.init_weight = ''
cfg.voxel2pts.diffusion_steps = [0]
cfg.dpm = CN()
cfg.dpm.train_encoder_only = 0
cfg.num_ref = 0 # manually set the number of references
cfg.eval_ddim_step = 0 # ddim sampling for the model evaluation
cfg.model_config = '' # used for model control, without adding new flags
## --- deprecated --- #
cfg.register_deprecated_key('cls') # CN()
cfg.register_deprecated_key('cls.classifier_type') # 'models.classifier.OneLayer'
cfg.register_deprecated_key('cls.train_on_eps') # 1
cfg.register_deprecated_key('cond_prior') # CN()
cfg.register_deprecated_key('cond_prior.grid_emb_resolution') # 32
cfg.register_deprecated_key('cond_prior.emb_dim') # 64
cfg.register_deprecated_key('cond_prior.use_voxel_feat') # 1
cfg.register_deprecated_key('cond_encoder_prior') # 'models.shapelatent_modules.VoxelGridEncoder'
cfg.register_deprecated_key('cond_prior.pvcconv_concat_3d_feat_input') # 0
cfg.register_deprecated_key('generate_mode_global') # 'interpolate'
cfg.register_deprecated_key('generate_mode_local') # 'freeze'
cfg.register_deprecated_key('normals') # CN()
cfg.register_deprecated_key('normals.model_type') # ''
cfg.register_deprecated_key('save_sample_seq_and_quit') # 0
cfg.register_deprecated_key('lns_loss_weight') # 1.0
cfg.register_deprecated_key('normal_pred_checkpoint') # ''
cfg.register_deprecated_key('lns') # CN()
cfg.register_deprecated_key('lns.override_config') # ''
cfg.register_deprecated_key('lns.wandb_checkpoint') # 'nvidia-toronto/generative_chairs/3m3gc6sz/checkpoint-171.pth'
cfg.register_deprecated_key('lns.num_input_points') # 1000
cfg.register_deprecated_key('lns.num_simulate') # 20
cfg.register_deprecated_key('lns.split_simulate') # 'train'
# use mesh-trainer or not
cfg.register_deprecated_key('with_lns') # 0
cfg.register_deprecated_key('normal_predictor_yaml') # ''
cfg.register_deprecated_key('pointtransformer') # CN()
# number of attention layer in each block
cfg.register_deprecated_key('pointtransformer.blocks') # [2, 3, 4, 6, 3]
cfg.register_deprecated_key('shapelatent.refiner_bp') # 1 # bp gradient to the local-decoder or not
cfg.register_deprecated_key('shapelatent.loss_weight_refiner') # 1.0 # weighted loss for the refiner
cfg.register_deprecated_key('shapelatent.refiner_type') # 'models.pvcnn2.PVCNN2BaseAPI' # mode for the refiner
cfg.register_deprecated_key('shapelatent.encoder_weight_std') # 0.1
cfg.register_deprecated_key('shapelatent.encoder_weight_norm') # 0
cfg.register_deprecated_key('shapelatent.encoder_weight_uniform') # 1
cfg.register_deprecated_key('shapelatent.key_point_gen') # 'mlps'
cfg.register_deprecated_key('shapelatent.add_sub_loss') # 1 # not used
cfg.register_deprecated_key('shapelatent.local_decoder_type') # ''
cfg.register_deprecated_key('shapelatent.local_decoder_type_1') # ''
cfg.register_deprecated_key('shapelatent.local_encoder_ball_radius') # 0.8
cfg.register_deprecated_key('shapelatent.local_encoder_ap_ball_radius') # 1.0
cfg.register_deprecated_key('shapelatent.local_encoder_type') # ''
cfg.register_deprecated_key('shapelatent.local_encoder_type_1') # ''
cfg.register_deprecated_key('shapelatent.local_loss_weight_max') # 50
cfg.register_deprecated_key('shapelatent.num_neighbors') # 0
cfg.register_deprecated_key('shapelatent.extra_centers') # []
# for latent model is flow
cfg.register_deprecated_key('shapelatent.latent_flow_depth') # 14
cfg.register_deprecated_key('shapelatent.latent_flow_hidden_dim') # 256
cfg.register_deprecated_key('shapelatent.bp_to_l0') # True
cfg.register_deprecated_key('shapelatent.global_only_epochs') # 0
cfg.register_deprecated_key('shapelatent.center_local_points') # 1
cfg.register_deprecated_key('shapelatent.hvae') # CN()
# alternatively way to compute the local loss
cfg.register_deprecated_key('shapelatent.hvae.loss_wrt_ori') # 0
# add voxel feature to the latent space; the decoder require pvc conv or query
cfg.register_deprecated_key('shapelatent.add_voxel2z_global') # 0
# reuse the encoder to get local latent
cfg.register_deprecated_key('shapelatent.query_output_local_from_enc') # 0
# check models/shapelatent_modules where the feature will be saved as a dict
cfg.register_deprecated_key('shapelatent.query_local_feat_layer') # 'inter_voxelfeat_0'
# need to check the sa_blocks of the global encoder
cfg.register_deprecated_key('shapelatent.query_local_feat_dim') # 32
# reuse the encoder to get local latent
cfg.register_deprecated_key('shapelatent.query_center_emd_from_enc') # 0 # reuse the encoder for center emd
cfg.register_deprecated_key('shapelatent.prog_dec_gf') # 8 # grow_factor in VaniDecoderProg
cfg.register_deprecated_key('shapelatent.prog_dec_gf_list') # [0, 0] # grow_factor in VaniDecoderProg
cfg.register_deprecated_key('shapelatent.prog_dec_ne') # 2 # num_expand in VaniDecoderProg
# increase number of hierarchy levels, used by hvaemul model
cfg.register_deprecated_key('shapelatent.num_neighbors_per_level') # [64] # number of neighbors for each level
cfg.register_deprecated_key('shapelatent.num_level') # 1 # number of hierarchical latent spaces (local)
cfg.register_deprecated_key('shapelatent.x0_target_fps') # 0 # let the target of global output as the
cfg.register_deprecated_key('shapelatent.downsample_input_ratio') # 1.0
# whether taking other tensor as input to local-encoder of not
cfg.register_deprecated_key('shapelatent.local_enc_input') # 'sim'
# local encoder take z0 as input at which location
cfg.register_deprecated_key('shapelatent.local_encoder_condition_z0') # ''
# output the absolution coordinates or the offset w.r.t centers
cfg.register_deprecated_key('shapelatent.local_decoder_output_offset') # 0
# feed coords of keypoints to the local prior model
cfg.register_deprecated_key('shapelatent.local_prior_need_coords') # 0
# add the time embedding tensor to each encoder layer instead of add to first layer only
cfg.register_deprecated_key('sde.transformer_temb2interlayer') # 0
# normalization used in transformer encoder;
cfg.register_deprecated_key('sde.transformer_norm_type') # 'layer_norm'
cfg.register_deprecated_key('data.has_normal') # 0 # for datasets/pointflow_rgb.py only
cfg.register_deprecated_key('data.has_color') # 0 # for datasets/pointflow_rgb.py only
cfg.register_deprecated_key('data.cls_data_ratio') # 1.0 # ratio of the training data
cfg.register_deprecated_key('data.sample_curvature') # 0 # only for datasets/shape_curvature
cfg.register_deprecated_key('data.ratio_c') # 1.0 # only for datasets/shape_curvature
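# Usage sketch: demo.py loads this default config and overrides it from a
# checkpoint's YAML file:
#   from default_config import cfg
#   cfg.merge_from_file('./lion_ckpt/text2shape/chair/cfg.yml')
#   print(cfg.shapelatent.latent_dim, cfg.sde.prior_model)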

demo.py Normal file

@@ -0,0 +1,45 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
require diffusers-0.11.1
"""
import os
import clip
import torch
from PIL import Image
from default_config import cfg as config
from models.lion import LION
from utils.vis_helper import plot_points
from huggingface_hub import hf_hub_download
model_path = './lion_ckpt/text2shape/chair/checkpoints/model.pt'
model_config = './lion_ckpt/text2shape/chair/cfg.yml'
config.merge_from_file(model_config)
lion = LION(config)
lion.load_model(model_path)
if config.clipforge.enable:
input_t = ["a swivel chair, five wheels"]
device_str = 'cuda'
clip_model, clip_preprocess = clip.load(
config.clipforge.clip_model, device=device_str)
text = clip.tokenize(input_t).to(device_str)
clip_feat = []
clip_feat.append(clip_model.encode_text(text).float())
clip_feat = torch.cat(clip_feat, dim=0)
print('clip_feat', clip_feat.shape)
else:
clip_feat = None
output = lion.sample(1 if clip_feat is None else clip_feat.shape[0], clip_feat=clip_feat)
pts = output['points']
img_name = "/tmp/tmp.png"
plot_points(pts, output_name=img_name)
img = Image.open(img_name)
img.show()

env.yaml Normal file

@@ -0,0 +1,311 @@
name: lion_env
channels:
- pytorch
- nvidia
- anaconda
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- argon2-cffi=20.1.0=py38h27cfd23_1
- async_generator=1.10=pyhd3eb1b0_0
- attrs=21.4.0=pyhd3eb1b0_0
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bleach=4.1.0=pyhd3eb1b0_0
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2020.10.14=0
- certifi=2020.6.20=py38_0
- cffi=1.15.0=py38hd667e15_1
- cmake=3.18.2=ha30ef3c_0
- cudatoolkit=11.1.74=h6bb024c_0
- debugpy=1.5.1=py38h295c915_0
- decorator=5.1.1=pyhd3eb1b0_0
- defusedxml=0.7.1=pyhd3eb1b0_0
- entrypoints=0.3=py38_0
- expat=2.2.10=he6710b0_2
- ffmpeg=4.3=hf484d3e_0
- freetype=2.11.0=h70c0345_0
- giflib=5.2.1=h7b6447c_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- importlib_metadata=4.8.2=hd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- ipykernel=6.4.1=py38h06a4308_1
- ipython=7.31.1=py38h06a4308_0
- ipython_genutils=0.2.0=pyhd3eb1b0_1
- ipywidgets=7.6.5=pyhd3eb1b0_1
- jedi=0.18.1=py38h06a4308_0
- jpeg=9d=h7f8727e_0
- jupyter_client=7.1.2=pyhd3eb1b0_0
- jupyter_core=4.9.1=py38h06a4308_0
- jupyterlab_pygments=0.1.2=py_0
- jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
- krb5=1.18.2=h173b8e3_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libcurl=7.71.1=h20c2e04_1
- libedit=3.1.20191231=h14c3975_1
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.2=h7f8727e_0
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libssh2=1.9.0=h1ba5d50_1
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.0=h89dd481_0
- libwebp-base=1.2.0=h27cfd23_0
- lz4-c=1.9.3=h295c915_1
- markupsafe=2.0.1=py38h27cfd23_0
- matplotlib-inline=0.1.2=pyhd3eb1b0_2
- mistune=0.8.4=py38h7b6447c_1000
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- nbclient=0.5.3=pyhd3eb1b0_0
- nbconvert=6.3.0=py38h06a4308_0
- ncurses=6.3=h7f8727e_2
- nest-asyncio=1.5.1=pyhd3eb1b0_0
- nettle=3.7.3=hbbd107a_1
- notebook=6.4.6=py38h06a4308_0
- numpy=1.21.2=py38h20f2e39_0
- numpy-base=1.21.2=py38h79a1101_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1m=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- pandocfilters=1.5.0=pyhd3eb1b0_0
- parso=0.8.3=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py38h5aabda8_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
- prometheus_client=0.13.1=pyhd3eb1b0_0
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pycparser=2.21=pyhd3eb1b0_0
- pygments=2.11.2=pyhd3eb1b0_0
- python=3.8.12=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-fastjsonschema=2.16.1=pyhd8ed1ab_0
- python_abi=3.8=2_cp38
- pytorch=1.10.2=py3.8_cuda11.1_cudnn8.0.5_0
- pytorch-mutex=1.0=cuda
- pyzmq=22.3.0=py38h295c915_2
- readline=8.1.2=h7f8727e_1
- rhash=1.4.0=h1ba5d50_0
- send2trash=1.8.0=pyhd3eb1b0_1
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.37.2=hc218d9a_0
- terminado=0.9.4=py38h06a4308_0
- testpath=0.5.0=pyhd3eb1b0_0
- tk=8.6.11=h1ccaba5_0
- torchaudio=0.10.2=py38_cu111
- torchvision=0.11.3=py38_cu111
- tornado=6.1=py38h27cfd23_0
- traitlets=5.1.1=pyhd3eb1b0_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- webencodings=0.5.1=py38_1
- wheel=0.37.1=pyhd3eb1b0_0
- widgetsnbextension=3.5.1=py38_0
- xz=5.2.5=h7b6447c_0
- zeromq=4.3.4=h2531618_0
- zipp=3.7.0=pyhd3eb1b0_0
- zlib=1.2.11=h7f8727e_4
- zstd=1.4.5=h9ceee32_0
- pip:
- about-time==3.1.1
- absl-py==1.0.0
- addict==2.4.0
- aiohttp==3.8.1
- aiosignal==1.2.0
- alive-progress==2.2.0
- antlr4-python3-runtime==4.9.3
- anyio==3.5.0
- astunparse==1.6.3
- async-timeout==4.0.2
- babel==2.9.1
- cachetools==5.0.0
- calmsize==0.1.3
- ccimport==0.3.7
- cftime==1.6.0
- charset-normalizer==2.0.11
- click==8.0.3
- colorama==0.4.4
- comet-ml==3.31.21
- commonmark==0.9.1
- configobj==5.0.6
- crc32c==2.2.post0
- cumm-cu111==0.2.8
- cupy-cuda111==10.2.0
- cycler==0.11.0
- cython==0.29.20
- dataclasses==0.6
- deepspeed==0.6.5
- deprecation==2.1.0
- diffusers==0.11.1
- docker-pycreds==0.4.0
- drjit==0.2.1
- dulwich==0.20.32
- easydict==1.9
- einops==0.4.0
- everett==3.0.0
- fastrlock==0.8
- filelock==3.9.0
- fire==0.4.0
- flatbuffers==2.0
- flatten-dict==0.4.2
- fonttools==4.29.1
- freetype-py==2.3.0
- frozenlist==1.3.0
- fsspec==2022.2.0
- ftfy==6.1.1
- future==0.18.2
- fvcore==0.1.5.post20220512
- gast==0.5.3
- gitdb==4.0.9
- gitpython==3.1.26
- google-auth==2.6.0
- google-auth-oauthlib==0.4.6
- google-pasta==0.2.0
- grapheme==0.6.0
- grpcio==1.43.0
- h5py==3.6.0
- hjson==3.0.2
- huggingface-hub==0.11.1
- idna==3.3
- imageio==2.15.0
- imageio-ffmpeg==0.4.5
- importlib-metadata==4.10.1
- importlib-resources==5.4.0
- iopath==0.1.10
- jinja2==3.1.1
- joblib==1.1.0
- json5==0.9.6
- jsonschema==4.4.0
- jupyter-packaging==0.12.0
- jupyter-server==1.15.6
- jupyterlab==3.3.2
- jupyterlab-server==2.11.2
- keras==2.8.0
- keras-preprocessing==1.1.2
- kiwisolver==1.3.2
- kornia==0.6.6
- lark==1.1.2
- libclang==14.0.1
- llvmlite==0.39.0
- loguru==0.6.0
- markdown==3.3.6
- matplotlib==3.5.1
- matplotlib2tikz==0.7.6
- meshio==5.3.4
- mitsuba==3.0.1
- mrcfile==1.3.0
- multidict==6.0.2
- multipledispatch==0.6.0
- mypy-extensions==0.4.3
- nbclassic==0.3.7
- nbformat==5.2.0
- nestargs==0.5.0
- netcdf4==1.5.8
- networkx==2.6.3
- ninja==1.10.2.3
- notebook-shim==0.1.0
- numba==0.56.0
- nvidia-ml-py3==7.352.0
- oauthlib==3.2.0
- omegaconf==2.2.2
- open3d==0.15.2
- opencv-python==4.5.5.64
- openexr==1.3.7
- opt-einsum==3.3.0
- pandas==1.4.0
- pathtools==0.1.2
- pccm==0.3.4
- pip==22.3.1
- plyfile==0.7.4
- portalocker==2.5.1
- progressbar2==4.0.0
- promise==2.3
- protobuf==3.19.4
- psutil==5.9.0
- py-cpuinfo==8.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pybind11==2.10.0
- pydeprecate==0.3.1
- pyglet==1.5.23
- pykeops==1.5
- pymcubes==0.1.2
- pyopengl==3.1.0
- pyparsing==3.0.7
- pyquaternion==0.9.9
- pyrr==0.10.3
- pyrsistent==0.18.1
- python-swiftclient==4.0.0
- python-utils==3.3.3
- pytorch-lightning==1.5.1
- pytorch3d==0.3.0
- pytz==2021.3
- pywavelets==1.2.0
- pyyaml==6.0
- regex==2022.3.15
- requests==2.27.1
- requests-oauthlib==1.3.1
- requests-toolbelt==0.9.1
- rich==12.3.0
- rsa==4.8
- ruamel-yaml==0.17.20
- ruamel-yaml-clib==0.2.6
- scikit-image==0.19.1
- scikit-learn==1.0.2
- scipy==1.8.0
- seaborn==0.11.2
- semantic-version==2.9.0
- sentry-sdk==1.5.4
- sharedarray==3.2.1
- shortuuid==1.0.8
- simple-parsing==0.0.18
- simplejson==3.18.0
- sklearn==0.0
- smmap==5.0.0
- sniffio==1.2.0
- tabulate==0.8.9
- tensorboard==2.8.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- tensorboardx==2.4.1
- tensorflow-gpu==2.8.0
- tensorflow-io-gcs-filesystem==0.25.0
- termcolor==1.1.0
- tf-estimator-nightly==2.8.0.dev2021122109
- tflearn==0.5.0
- tfrecord==1.14.1
- threadpoolctl==3.1.0
- tifffile==2022.2.2
- tikzplotlib==0.10.1
- tomlkit==0.10.0
- torchmetrics==0.7.2
- tqdm==4.62.3
- trimesh==3.10.1
- typing-extensions==4.2.0
- typing-inspect==0.7.1
- urllib3==1.26.8
- wandb==0.12.10
- webcolors==1.11.1
- websocket-client==1.2.3
- werkzeug==2.0.3
- wrapt==1.13.3
- wurlitzer==3.0.2
- yacs==0.1.8
- yarl==1.7.2
- yaspin==2.1.0

models/adagn.py Normal file

@@ -0,0 +1,67 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
adaptive group norm
"""
from loguru import logger
import torch.nn as nn
import torch
import numpy as np
from utils.checker import *
from .dense import dense
import os
class AdaGN(nn.Module):
'''
adaptive group normalization
'''
def __init__(self, ndim, cfg, n_channel):
"""
ndim: dim of the input features
n_channel: number of channels of the inputs
ndim_style: channel of the style features
"""
super().__init__()
style_dim = cfg.latent_pts.style_dim
init_scale = cfg.latent_pts.ada_mlp_init_scale
self.ndim = ndim
self.n_channel = n_channel
self.style_dim = style_dim
self.out_dim = n_channel * 2
self.norm = nn.GroupNorm(8, n_channel)
in_channel = n_channel
self.emd = dense(style_dim, n_channel*2, init_scale=init_scale)
self.emd.bias.data[:in_channel] = 1
self.emd.bias.data[in_channel:] = 0
def __repr__(self):
return f"AdaGN(GN(8, {self.n_channel}), Linear({self.style_dim}, {self.out_dim}))"
def forward(self, image, style):
# style: B,D
# image: B,D,N,1
CHECK2D(style)
style = self.emd(style)
if self.ndim == 3: #B,D,V,V,V
CHECK5D(image)
style = style.view(style.shape[0], -1, 1, 1, 1) # 5D
elif self.ndim == 2: # B,D,N,1
CHECK4D(image)
style = style.view(style.shape[0], -1, 1, 1) # 4D
elif self.ndim == 1: # B,D,N
CHECK3D(image)
style = style.view(style.shape[0], -1, 1) # 3D
else:
raise NotImplementedError
factor, bias = style.chunk(2, 1)
result = self.norm(image)
result = result * factor + bias
return result
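# Usage sketch (assumed shapes; `cfg` stands in for the repo's yacs config with
# cfg.latent_pts.style_dim = 128 and cfg.latent_pts.ada_mlp_init_scale = 1.0):
#   adagn = AdaGN(ndim=2, cfg=cfg, n_channel=64)
#   image = torch.randn(4, 64, 2048, 1)   # B,D,N,1 point features
#   style = torch.randn(4, 128)           # B,style_dim global latent
#   out = adagn(image, style)             # same shape as image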

models/dense.py Normal file

@@ -0,0 +1,80 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" copied and modified from https://github.com/CW-Huang/sdeflow-light/blob/524650bc5ad69522b3e0905672deef0650374512/lib/models/unet.py """
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
import numpy as np
def _calculate_correct_fan(tensor, mode):
"""
copied and modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L337
"""
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out', 'fan_avg']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(tensor, gain=1., mode='fan_in'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
tensor: an n-dimensional `torch.Tensor`
gain: multiplier to the dispersion
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in')
"""
fan = _calculate_correct_fan(tensor, mode)
# gain = calculate_gain(nonlinearity, a)
var = gain / max(1., fan)
bound = math.sqrt(3.0 * var) # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def variance_scaling_init_(tensor, scale):
return kaiming_uniform_(tensor, gain=1e-10 if scale == 0 else scale, mode='fan_avg')
def dense(in_channels, out_channels, init_scale=1.):
lin = nn.Linear(in_channels, out_channels)
variance_scaling_init_(lin.weight, scale=init_scale)
nn.init.zeros_(lin.bias)
return lin
def conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=1, dilation=1, padding=1, bias=True, padding_mode='zeros',
init_scale=1.):
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
bias=bias, padding_mode=padding_mode)
variance_scaling_init_(conv.weight, scale=init_scale)
if bias:
nn.init.zeros_(conv.bias)
return conv
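# Usage sketch: `init_scale` rescales the fan-avg uniform init, so small values
# give a near-zero initial mapping (models/adagn.py uses this for its
# modulation MLP):
#   layer = dense(128, 256, init_scale=1.0)  # Linear, variance-scaled weights, zero bias
#   conv = conv2d(32, 64, init_scale=0.)     # weights ~ 0 at init (scale 0 -> gain 1e-10)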

models/distributions.py Normal file

@@ -0,0 +1,37 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import numpy as np
@torch.jit.script
def sample_normal_jit(mu, sigma):
rho = mu.mul(0).normal_()
z = rho.mul_(sigma).add_(mu)
return z, rho
class Normal:
def __init__(self, mu, log_sigma, sigma=None):
self.mu = mu
self.log_sigma = log_sigma
self.sigma = torch.exp(log_sigma) if sigma is None else sigma
def sample(self, t=1.):
return sample_normal_jit(self.mu, self.sigma * t)
def sample_given_rho(self, rho):
return rho * self.sigma + self.mu
def mean(self):
return self.mu
def log_p(self, samples):
normalized_samples = (samples - self.mu) / self.sigma
log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - self.log_sigma
return log_p
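# Usage sketch: reparameterized sampling and elementwise log-density:
#   dist = Normal(mu=torch.zeros(2, 4), log_sigma=torch.zeros(2, 4))
#   z, rho = dist.sample()   # z = mu + sigma * rho, with rho ~ N(0, I)
#   logp = dist.log_p(z)     # same shape as z; sum over dims for a total log-prob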

models/latent_points_ada.py Normal file

@@ -0,0 +1,273 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
from loguru import logger
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .pvcnn2_ada import \
create_pointnet2_sa_components, create_pointnet2_fp_modules, LinearAttention, create_mlp_components, SharedMLP
# the building block of encode and decoder for VAE
class PVCNN2Unet(nn.Module):
"""
copied and modified from https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py#L172
"""
def __init__(self,
num_classes, embed_dim, use_att, dropout=0.1,
extra_feature_channels=3,
input_dim=3,
width_multiplier=1,
voxel_resolution_multiplier=1,
time_emb_scales=1.0,
verbose=True,
condition_input=False,
point_as_feat=1, cfg={},
sa_blocks={}, fp_blocks={},
clip_forge_enable=0,
clip_forge_dim=512
):
super().__init__()
logger.info('[Build Unet] extra_feature_channels={}, input_dim={}',
extra_feature_channels, input_dim)
self.input_dim = input_dim
self.clip_forge_enable = clip_forge_enable
self.sa_blocks = sa_blocks
self.fp_blocks = fp_blocks
self.point_as_feat = point_as_feat
self.condition_input = condition_input
assert extra_feature_channels >= 0
self.time_emb_scales = time_emb_scales
self.embed_dim = embed_dim
## assert(self.embed_dim == 0)
if self.embed_dim > 0: # has time embedding
# for prior model, we have time embedding, for VAE model, no time embedding
self.embedf = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.LeakyReLU(0.1, inplace=True),
nn.Linear(embed_dim, embed_dim),
)
if self.clip_forge_enable:
self.clip_forge_mapping = nn.Linear(clip_forge_dim, embed_dim)
style_dim = cfg.latent_pts.style_dim
self.style_clip = nn.Linear(style_dim + embed_dim, style_dim)
self.in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels, channels_sa_features, _ = \
create_pointnet2_sa_components(
input_dim=input_dim,
sa_blocks=self.sa_blocks,
extra_feature_channels=extra_feature_channels,
with_se=True,
embed_dim=embed_dim, # time embedding dim
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier,
verbose=verbose, cfg=cfg
)
self.sa_layers = nn.ModuleList(sa_layers)
self.global_att = None if not use_att else LinearAttention(channels_sa_features, 8, verbose=verbose)
# only use extra features in the last fp module
sa_in_channels[0] = extra_feature_channels + input_dim - 3
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks, in_channels=channels_sa_features,
sa_in_channels=sa_in_channels,
with_se=True, embed_dim=embed_dim,
use_att=use_att, dropout=dropout,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier,
verbose=verbose, cfg=cfg
)
self.fp_layers = nn.ModuleList(fp_layers)
layers, _ = create_mlp_components(
in_channels=channels_fp_features,
out_channels=[128, dropout, num_classes], # was 0.5
classifier=True, dim=2, width_multiplier=width_multiplier,
cfg=cfg)
self.classifier = nn.ModuleList(layers)
def get_timestep_embedding(self, timesteps, device):
if len(timesteps.shape) == 2 and timesteps.shape[1] == 1:
timesteps = timesteps[:,0]
assert(len(timesteps.shape) == 1), f'get shape: {timesteps.shape}'
timesteps = timesteps * self.time_emb_scales
half_dim = self.embed_dim // 2
emb = np.log(10000) / (half_dim - 1)
emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
emb = timesteps[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if self.embed_dim % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), "constant", 0)
assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim])
return emb
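# worked example: with embed_dim=4 and t=[0], the frequencies are exp([0, -ln(10000)]) = [1, 1e-4],
# so emb = [[sin(0), sin(0), cos(0), cos(0)]] = [[0., 0., 1., 1.]]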
def forward(self, inputs, **kwargs):
# Input: coords: B3N
B = inputs.shape[0]
coords = inputs[:, :self.input_dim, :].contiguous()
features = inputs
temb = kwargs.get('t', None)
if temb is not None:
t = temb
if t.ndim == 0:
t = t.view(1).expand(B)
temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(-1, -1, inputs.shape[-1])
temb_ori = temb # B,embed_dim,Npoint
style = kwargs['style']
if self.clip_forge_enable:
clip_feat = kwargs['clip_feat']
assert(clip_feat is not None), f'require clip_feat as input'
clip_feat = self.clip_forge_mapping(clip_feat)
style = torch.cat([style, clip_feat], dim=1).contiguous()
style = self.style_clip(style)
coords_list, in_features_list = [], []
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
if i > 0 and temb is not None:
# TODO: implement a dedicated sa_blocks forward; if the layer is a PVConv
# and kwargs provides grid_emb, take it as an additional input
features = torch.cat([features, temb], dim=1)
features, coords, temb, _ = sa_blocks((features, coords, temb, style))
else: # i == 0 or temb is None
features, coords, temb, _ = sa_blocks((features, coords, temb, style))
in_features_list[0] = inputs[:, 3:, :].contiguous()
if self.global_att is not None:
features = self.global_att(features)
for fp_idx, fp_blocks in enumerate(self.fp_layers):
if temb is not None:
features, coords, temb, _ = fp_blocks((
coords_list[-1-fp_idx], coords,
torch.cat([features,temb],dim=1),
in_features_list[-1-fp_idx], temb, style))
else:
features, coords, temb, _ = fp_blocks((
coords_list[-1-fp_idx], coords,
features,
in_features_list[-1-fp_idx], temb, style))
for l in self.classifier:
if isinstance(l, SharedMLP):
features = l(features, style)
else:
features = l(features)
return features
class PointTransPVC(nn.Module):
# encoder : B,N,3 -> B,N,2*D
sa_blocks = [ # conv_configs, sa_configs
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (128, 128, 128))),
]
fp_blocks = [
((128, 128), (128, 3, 8)), # fp_configs, conv_configs
((128, 128), (128, 3, 8)),
((128, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, zdim, input_dim, args={}):
super().__init__()
self.zdim = zdim
self.layers = PVCNN2Unet(2*zdim+input_dim*2,
embed_dim=0, use_att=1, extra_feature_channels=0,
input_dim=args.ddpm.input_dim, cfg=args,
sa_blocks=self.sa_blocks, fp_blocks=self.fp_blocks,
dropout=args.ddpm.dropout)
self.skip_weight = args.latent_pts.skip_weight
self.pts_sigma_offset = args.latent_pts.pts_sigma_offset
self.input_dim = input_dim
def forward(self, inputs):
x, style = inputs
B,N,D = x.shape
output = self.layers(x.permute(0,2,1).contiguous(), style=style).permute(0,2,1).contiguous() # BND
pt_mu_1d = output[:,:,:self.input_dim].contiguous()
pt_sigma_1d = output[:,:,self.input_dim:2*self.input_dim].contiguous() - self.pts_sigma_offset
pt_mu_1d = self.skip_weight * pt_mu_1d + x
if self.zdim > 0:
ft_mu_1d = output[:,:,2*self.input_dim:-self.zdim].contiguous()
ft_sigma_1d = output[:,:,-self.zdim:].contiguous()
mu_1d = torch.cat([pt_mu_1d, ft_mu_1d], dim=2).view(B,-1).contiguous()
sigma_1d = torch.cat([pt_sigma_1d, ft_sigma_1d], dim=2).view(B,-1).contiguous()
else:
mu_1d = pt_mu_1d.view(B,-1).contiguous()
sigma_1d = pt_sigma_1d.view(B,-1).contiguous()
return {'mu_1d': mu_1d, 'sigma_1d': sigma_1d}
class LatentPointDecPVC(nn.Module):
""" input x: [B,Npoint,D] with [B,Npoint,3]
"""
sa_blocks = [ # conv_configs, sa_configs
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 256))),
(None, (16, 0.8, 32, (128, 128, 128))),
]
fp_blocks = [
((128, 128), (128, 3, 8)), # fp_configs, conv_configs
((128, 128), (128, 3, 8)),
((128, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, point_dim, context_dim, num_points=None, args={}, **kwargs):
super().__init__()
self.point_dim = point_dim
logger.info('[Build Dec] point_dim={}, context_dim={}', point_dim, context_dim)
self.context_dim = context_dim + self.point_dim
# self.num_points = num_points
if num_points is None:
self.num_points = args.data.tr_max_sample_points
else:
self.num_points = num_points
self.layers = PVCNN2Unet(point_dim, embed_dim=0, use_att=1,
extra_feature_channels=context_dim,
input_dim=args.ddpm.input_dim, cfg=args,
sa_blocks=self.sa_blocks, fp_blocks=self.fp_blocks,
dropout=args.ddpm.dropout)
self.skip_weight = args.latent_pts.skip_weight
def forward(self, x, beta, context, style):
"""
Args:
x: Point clouds at some timestep t, (B, N, d). [not used]
beta: Time. (B, ). [not used]
context: Latent points, (B,N_pts*D_latent_pts), D_latent_pts = D_input + D_extra
style: Shape latents. (B,d).
Returns:
points: (B,N,3)
"""
# CHECKDIM(context, 1, self.num_points*self.context_dim)
assert(context.shape[1] == self.num_points*self.context_dim)
context = context.view(-1,self.num_points,self.context_dim) # BND
x = context[:,:,:self.point_dim]
output = self.layers(context.permute(0,2,1).contiguous(), style=style).permute(0,2,1).contiguous() # BN3
output = output * self.skip_weight + x
return output
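To make the latent layout above concrete, here is a standalone sketch (hypothetical sizes, independent of the repo) of how `PointTransPVC.forward` slices the UNet output into point and feature means/sigmas and flattens them to `(B, N*(input_dim+zdim))`:

```
import torch

B, N, input_dim, zdim = 2, 2048, 3, 1                 # hypothetical sizes
skip_weight, pts_sigma_offset = 0.01, 0.0             # hypothetical cfg values
x = torch.randn(B, N, input_dim)                      # input point cloud
output = torch.randn(B, N, 2 * zdim + 2 * input_dim)  # stand-in for the UNet output

pt_mu = skip_weight * output[:, :, :input_dim] + x    # residual skip to the input points
pt_sigma = output[:, :, input_dim:2 * input_dim] - pts_sigma_offset
ft_mu = output[:, :, 2 * input_dim:-zdim]
ft_sigma = output[:, :, -zdim:]

mu_1d = torch.cat([pt_mu, ft_mu], dim=2).reshape(B, -1)
assert mu_1d.shape == (B, N * (input_dim + zdim))     # (2, 2048*4)
```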

models/latent_points_ada_localprior.py Normal file

@ -0,0 +1,84 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
from loguru import logger
import torch.nn as nn
import torch.nn.functional as F
from .latent_points_ada import PVCNN2Unet
from .utils import mask_inactive_variables
# diffusion model for latent points
class PVCNN2Prior(PVCNN2Unet):
sa_blocks = [ # conv_configs, sa_configs
((32, 2, 32), (1024, 0.1, 32, (32, 64))),
((64, 3, 16), (256, 0.2, 32, (64, 128))),
((128, 3, 8), (64, 0.4, 32, (128, 128))),
(None, (16, 0.8, 32, (128, 128, 128))),
]
fp_blocks = [
((128, 128), (128, 3, 8)), # fp_configs, conv_configs
((128, 128), (128, 3, 8)),
((128, 128), (128, 2, 16)),
((128, 128, 64), (64, 2, 32)),
]
def __init__(self, args, num_input_channels, cfg):
# only cfg is used
self.clip_forge_enable = cfg.clipforge.enable
clip_forge_dim = cfg.clipforge.feat_dim
num_input_channels = num_classes = cfg.shapelatent.latent_dim + cfg.ddpm.input_dim
self.num_classes = num_classes
embed_dim = cfg.ddpm.time_dim
use_att = True
extra_feature_channels = cfg.shapelatent.latent_dim
self.num_points = cfg.data.tr_max_sample_points
dropout = cfg.ddpm.dropout
time_emb_scales = cfg.sde.embedding_scale # 1k default
logger.info('[Build Prior Model] nclass={}, embed_dim={}, use_att={},'
'extra_feature_channels={}, dropout={}, time_emb_scales={} num_point={}',
num_classes, embed_dim, use_att, extra_feature_channels, dropout, time_emb_scales,
self.num_points)
# note: time_emb_scales is set from cfg.sde.embedding_scale above, not from the ddpm config
super().__init__(
num_classes, embed_dim, use_att, dropout=dropout,
input_dim=cfg.ddpm.input_dim,
extra_feature_channels=extra_feature_channels,
time_emb_scales=time_emb_scales,
verbose=True,
condition_input=False,
cfg=cfg,
sa_blocks=self.sa_blocks,
fp_blocks=self.fp_blocks,
clip_forge_enable=self.clip_forge_enable, clip_forge_dim=clip_forge_dim)
# init mixing logit
self.mixed_prediction = cfg.sde.mixed_prediction # This enables mixed prediction
if self.mixed_prediction:
logger.info('init-mixing_logit = {}, after sigmoid = {}',
cfg.sde.mixing_logit_init, torch.sigmoid(torch.tensor(cfg.sde.mixing_logit_init))
)
init = cfg.sde.mixing_logit_init * torch.ones(size=[1, num_input_channels*self.num_points, 1, 1])
self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
self.is_active = None
else: # no mixing_logit
self.mixing_logit = None
self.is_active = None
def forward(self, x, t, *args, **kwargs):
# input x: (B, N*D) or (B, N*D, 1, 1); the inner UNet expects (B, D, N)
assert 'condition_input' in kwargs, 'require condition_input'
if self.mixed_prediction and self.is_active is not None:
x = mask_inactive_variables(x, self.is_active)
input_shape = x.shape
x = x.view(-1,self.num_points,self.num_classes).permute(0,2,1).contiguous()
B = x.shape[0]
out = super().forward(x, t=t, style=kwargs['condition_input'].squeeze(-1).squeeze(-1), clip_feat=kwargs.get('clip_feat', None))
return out.permute(0, 2, 1).contiguous().view(input_shape) # (B,D,N) -> (B,N,D) -> input shape
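As a sanity check, the `(B, N*D, 1, 1) -> (B, D, N) -> (B, N*D, 1, 1)` reshape round-trip used in `forward` above, as a standalone sketch with hypothetical sizes:

```
import torch

B, N, D = 4, 2048, 4                 # hypothetical: 2048 latent points, xyz + 1 extra feature
x = torch.randn(B, N * D, 1, 1)      # flat latent, as the diffusion model receives it
input_shape = x.shape

h = x.view(-1, N, D).permute(0, 2, 1).contiguous()       # (B, D, N) for the point-voxel UNet
out = h.permute(0, 2, 1).contiguous().view(input_shape)  # back to the flat layout
assert torch.equal(out, x)           # the round-trip is lossless
```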

models/lion.py Normal file

@ -0,0 +1,91 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
from models.vae_adain import Model as VAE
from models.latent_points_ada_localprior import PVCNN2Prior as LocalPrior
from utils.diffusion_pvd import DiffusionDiscretized
from utils.vis_helper import plot_points
from utils.model_helper import import_model
from diffusers import DDPMScheduler
import torch
import numpy as np
from matplotlib import pyplot as plt
class LION(object):
def __init__(self, cfg):
self.vae = VAE(cfg).cuda()
GlobalPrior = import_model(cfg.latent_pts.style_prior)
global_prior = GlobalPrior(cfg.sde, cfg.latent_pts.style_dim, cfg).cuda()
local_prior = LocalPrior(cfg.sde, cfg.shapelatent.latent_dim, cfg).cuda()
self.priors = torch.nn.ModuleList([global_prior, local_prior])
self.scheduler = DDPMScheduler(clip_sample=False,
beta_start=cfg.ddpm.beta_1, beta_end=cfg.ddpm.beta_T, beta_schedule=cfg.ddpm.sched_mode,
num_train_timesteps=cfg.ddpm.num_steps, variance_type=cfg.ddpm.model_var_type)
self.diffusion = DiffusionDiscretized(None, None, cfg)
# self.load_model(cfg)
def load_model(self, model_path):
# model_path = cfg.ckpt.path
ckpt = torch.load(model_path)
self.priors.load_state_dict(ckpt['dae_state_dict'])
self.vae.load_state_dict(ckpt['vae_state_dict'])
print(f'INFO finish loading from {model_path}')
@torch.no_grad()
def sample(self, num_samples=10, clip_feat=None, save_img=False):
self.scheduler.set_timesteps(1000, device='cuda')
timesteps = self.scheduler.timesteps
latent_shape = self.vae.latent_shape()
global_prior, local_prior = self.priors[0], self.priors[1]
assert(not local_prior.mixed_prediction and not global_prior.mixed_prediction)
sampled_list = []
output_dict = {}
# sample the global shape latent from the global prior
x_T_shape = [num_samples] + latent_shape[0]
x_noisy = torch.randn(size=x_T_shape, device='cuda')
condition_input = None
for i, t in enumerate(timesteps):
t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1)
noise_pred = global_prior(x=x_noisy, t=t_tensor.float(),
condition_input=condition_input, clip_feat=clip_feat)
x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
sampled_list.append(x_noisy)
output_dict['z_global'] = x_noisy
condition_input = x_noisy
condition_input = self.vae.global2style(condition_input)
# sample the latent points from the local prior
x_T_shape = [num_samples] + latent_shape[1]
x_noisy = torch.randn(size=x_T_shape, device='cuda')
for i, t in enumerate(timesteps):
t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t+1)
noise_pred = local_prior(x=x_noisy, t=t_tensor.float(),
condition_input=condition_input, clip_feat=clip_feat)
x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
sampled_list.append(x_noisy)
output_dict['z_local'] = x_noisy
# decode the latent
output = self.vae.sample(num_samples=num_samples, decomposed_eps=sampled_list)
if save_img:
out_name = plot_points(output, "/tmp/tmp.png")
print(f'INFO save plot image at {out_name}')
output_dict['points'] = output
return output_dict
def get_mixing_component(self, noise_pred, t):
# usage:
# if global_prior.mixed_prediction:
# mixing_component = self.get_mixing_component(noise_pred, t)
# coeff = torch.sigmoid(global_prior.mixing_logit)
# noise_pred = (1 - coeff) * mixing_component + coeff * noise_pred
alpha_bar = self.scheduler.alphas_cumprod[t]
one_minus_alpha_bars_sqrt = np.sqrt(1.0 - alpha_bar)
return noise_pred * one_minus_alpha_bars_sqrt
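The two sampling loops above follow the standard `diffusers` reverse-diffusion pattern. A self-contained toy version of one loop, with a dummy network standing in for the trained priors:

```
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(clip_sample=False, num_train_timesteps=1000)
scheduler.set_timesteps(1000)

toy_prior = lambda x, t: torch.zeros_like(x)  # stand-in for global_prior / local_prior
x = torch.randn(4, 128)                       # x_T ~ N(0, I)
for t in scheduler.timesteps:
    noise_pred = toy_prior(x, t)
    x = scheduler.step(noise_pred, t, x).prev_sample  # one ancestral sampling step
```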

models/pvcnn2.py Normal file

@ -0,0 +1,557 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
copied and modified from source:
https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py
and functions under
https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules
"""
import copy
import functools
from loguru import logger
from einops import rearrange
import torch.nn as nn
import torch
import numpy as np
import third_party.pvcnn.functional as F
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
class SE3d(nn.Module):
def __init__(self, channel, reduction=8):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
self.channel = channel
def __repr__(self):
return f"SE({self.channel}, {self.channel})"
def forward(self, inputs):
return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)
class LinearAttention(nn.Module):
"""
copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159
"""
def __init__(self, dim, heads = 4, dim_head = 32, verbose=True):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
'''
Args:
x: torch.tensor (B,C,N), C=num-channels, N=num-points
Returns:
out: torch.tensor (B,C,N)
'''
x = x.unsqueeze(-1) # add w dimension
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
out = self.to_out(out)
out = out.squeeze(-1) # B,C,N,1 -> B,C,N
return out
def swish(input):
return input * torch.sigmoid(input)
class Swish(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return swish(input)
class BallQuery(nn.Module):
def __init__(self, radius, num_neighbors, include_coordinates=True):
super().__init__()
self.radius = radius
self.num_neighbors = num_neighbors
self.include_coordinates = include_coordinates
@custom_bwd
def backward(self, *args, **kwargs):
return super().backward(*args, **kwargs)
@custom_fwd(cast_inputs=torch.float32)
def forward(self, points_coords, centers_coords, points_features=None):
# input: BCN, BCN
# returns:
# neighbor_features: B,D(+3),Ncenter
points_coords = points_coords.contiguous()
centers_coords = centers_coords.contiguous()
neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
if points_features is None:
assert self.include_coordinates, 'No Features For Grouping'
neighbor_features = neighbor_coordinates
else:
neighbor_features = F.grouping(points_features, neighbor_indices)
if self.include_coordinates:
neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
return neighbor_features
def extra_repr(self):
return 'radius={}, num_neighbors={}{}'.format(
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
class SharedMLP(nn.Module):
def __init__(self, in_channels, out_channels, dim=1):
super().__init__()
if dim==1:
conv = nn.Conv1d
else:
conv = nn.Conv2d
bn = nn.GroupNorm
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
layers = []
for oc in out_channels:
layers.append( conv(in_channels, oc, 1))
layers.append(bn(8, oc))
layers.append(Swish())
in_channels = oc
self.layers = nn.Sequential(*layers)
def forward(self, inputs):
if isinstance(inputs, (list, tuple)):
return (self.layers(inputs[0]), *inputs[1:])
else:
return self.layers(inputs)
class Voxelization(nn.Module):
def __init__(self, resolution, normalize=True, eps=0):
super().__init__()
self.r = int(resolution)
self.normalize = normalize
self.eps = eps
def forward(self, features, coords):
# features: B,D,N
# coords: B,3,N
coords = coords.detach()
norm_coords = coords - coords.mean(2, keepdim=True)
if self.normalize:
norm_coords = norm_coords / (norm_coords.norm(
dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 +
self.eps) + 0.5
else:
norm_coords = (norm_coords + 1) / 2.0
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
vox_coords = torch.round(norm_coords).to(torch.int32)
if features is None:
return features, norm_coords
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
def extra_repr(self):
return 'resolution={}{}'.format(
self.r,
', normalized eps = {}'.format(self.eps) if self.normalize else '')
class PVConv(nn.Module):
def __init__(self, in_channels, out_channels,
kernel_size, resolution,
normalize=1, eps=0, with_se=False,
add_point_feat=True, attention=False,
dropout=0.1, verbose=True
):
super().__init__()
self.resolution = resolution
self.voxelization = Voxelization(resolution,
normalize=normalize,
eps=eps)
# For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention)
voxel_layers = [
nn.Conv3d(in_channels,
out_channels,
kernel_size, stride=1,
padding=kernel_size // 2),
nn.GroupNorm(8, out_channels),
Swish(),
nn.Dropout(dropout),
nn.Conv3d(out_channels, out_channels,
kernel_size, stride=1,
padding=kernel_size // 2),
nn.GroupNorm(8, out_channels)
]
if with_se:
voxel_layers.append(SE3d(out_channels))
self.voxel_layers = nn.Sequential(*voxel_layers)
if attention:
self.attn = LinearAttention(out_channels, verbose=verbose)
else:
self.attn = None
if add_point_feat:
self.point_features = SharedMLP(in_channels, out_channels) #, **mlp_kwargs)
self.add_point_feat = add_point_feat
def forward(self, inputs):
'''
Args:
inputs: tuple of features and coords
features: B,feat-dim,num-points
coords: B,3, num-points
Returns:
fused_features: in (B,out-feat-dim,num-points)
coords : in (B, 3, num_points); same as the input coords
'''
features = inputs[0]
coords_input = inputs[1]
time_emb = inputs[2]
## features, coords_input, time_emb = inputs
if coords_input.shape[1] > 3:
coords = coords_input[:,:3] # the last 3 dim are other point attributes if any
else:
coords = coords_input
assert (features.shape[0] == coords.shape[0]
), f'get feat: {features.shape} and {coords.shape}'
assert (features.shape[2] == coords.shape[2]
), f'get feat: {features.shape} and {coords.shape}'
assert (coords.shape[1] == 3
), f'expect coords: B,3,Npoint, get: {coords.shape}'
# features: B,D,N; point_features
# coords: B,3,N
voxel_features_4d, voxel_coords = self.voxelization(features, coords)
r = self.resolution
B = coords.shape[0]
voxel_features_4d = self.voxel_layers(voxel_features_4d)
voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords,
r, self.training)
fused_features = voxel_features
if self.add_point_feat:
fused_features = fused_features + self.point_features(features)
if self.attn is not None:
fused_features = self.attn(fused_features)
if time_emb is None:
time_emb = {'voxel_features_4d': voxel_features_4d, 'resolution': self.resolution, 'training': self.training}
return fused_features, coords_input, time_emb #inputs[2]
class PointNetAModule(nn.Module):
def __init__(self, in_channels, out_channels, include_coordinates=True):
super().__init__()
if not isinstance(out_channels, (list, tuple)):
out_channels = [[out_channels]]
elif not isinstance(out_channels[0], (list, tuple)):
out_channels = [out_channels]
mlps = []
total_out_channels = 0
for _out_channels in out_channels:
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=1)
)
total_out_channels += _out_channels[-1]
self.include_coordinates = include_coordinates
self.out_channels = total_out_channels
self.mlps = nn.ModuleList(mlps)
def forward(self, inputs):
features, coords, time_emb = inputs
if self.include_coordinates:
features = torch.cat([features, coords], dim=1)
coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
if len(self.mlps) > 1:
features_list = []
for mlp in self.mlps:
features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
return torch.cat(features_list, dim=1), coords, time_emb
else:
return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords, time_emb
def extra_repr(self):
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
class PointNetSAModule(nn.Module):
def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
super().__init__()
if not isinstance(radius, (list, tuple)):
radius = [radius]
if not isinstance(num_neighbors, (list, tuple)):
num_neighbors = [num_neighbors] * len(radius)
assert len(radius) == len(num_neighbors)
if not isinstance(out_channels, (list, tuple)):
out_channels = [[out_channels]] * len(radius)
elif not isinstance(out_channels[0], (list, tuple)):
out_channels = [out_channels] * len(radius)
assert len(radius) == len(out_channels)
groupers, mlps = [], []
total_out_channels = 0
for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
groupers.append(
BallQuery(radius=_radius, num_neighbors=_num_neighbors,
include_coordinates=include_coordinates)
)
# logger.info('create MLP: in_channel={}, out_channels={}',
# in_channels + (3 if include_coordinates else 0),_out_channels)
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0) ,
out_channels=_out_channels, dim=2)
)
total_out_channels += _out_channels[-1]
self.num_centers = num_centers
self.out_channels = total_out_channels
self.groupers = nn.ModuleList(groupers)
self.mlps = nn.ModuleList(mlps)
def forward(self, inputs):
# features, coords, _ = inputs
features = inputs[0]
coords = inputs[1] # B3N
if coords.shape[1] > 3:
coords = coords[:,:3]
centers_coords = F.furthest_point_sample(coords, self.num_centers)
# centers_coords: (B, 3, num_centers)
S = centers_coords.shape[-1]
time_emb = inputs[2]
if time_emb is not None and not isinstance(time_emb, dict):
time_emb = time_emb[:, :, :S]
features_list = []
for grouper, mlp in zip(self.groupers, self.mlps):
grouper_output = grouper(coords, centers_coords, features)
features_list.append(mlp(grouper_output).max(dim=-1).values)
if len(features_list) > 1:
return torch.cat(features_list, dim=1), centers_coords, time_emb
else:
return features_list[0], centers_coords, time_emb
def extra_repr(self):
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
class PointNetFPModule(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
def forward(self, inputs):
if len(inputs) == 4:
points_coords, centers_coords, centers_features, time_emb = inputs
points_features = None
else:
points_coords, centers_coords, centers_features, points_features, time_emb = inputs
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
if points_features is not None:
interpolated_features = torch.cat(
[interpolated_features, points_features], dim=1
)
if time_emb is not None:
B,D,S = time_emb.shape
N = points_coords.shape[-1]
time_emb = time_emb[:,:,0:1].expand(-1,-1,N)
return self.mlp(interpolated_features), points_coords, time_emb
def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
r = width_multiplier
if dim == 1:
block = _linear_gn_relu
else:
block = SharedMLP
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
return nn.Sequential(), in_channels, in_channels
layers = []
for oc in out_channels[:-1]:
if oc < 1:
layers.append(nn.Dropout(oc))
else:
oc = int(r * oc)
layers.append(block(in_channels, oc))
in_channels = oc
if dim == 1:
if classifier:
layers.append(nn.Linear(in_channels, out_channels[-1]))
else:
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
else:
if classifier:
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
else:
layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1, verbose=True):
r, vr = width_multiplier, voxel_resolution_multiplier
layers, concat_channels = [], 0
c = 0
for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = k % 2 == 0 and k > 0 and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
with_se=with_se, normalize=normalize, eps=eps, verbose=verbose)
if c == 0:
layers.append(block(in_channels, out_channels))
else:
layers.append(block(in_channels+embed_dim, out_channels))
in_channels = out_channels
concat_channels += out_channels
c += 1
return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels,
input_dim=3,
embed_dim=64, use_att=False, force_att=0,
dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1,
width_multiplier=1, voxel_resolution_multiplier=1, verbose=True):
"""
Returns:
in_channels: the last output channels of the sa blocks
"""
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + input_dim
sa_layers, sa_in_channels = [], []
c = 0
num_centers = None
for conv_configs, sa_configs in sa_blocks:
k = 0
sa_in_channels.append(in_channels)
sa_blocks = []
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = ( (c+1) % 2 == 0 and use_att and p == 0 ) or (force_att and c > 0)
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(
PVConv, kernel_size=3,
resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se,
normalize=normalize, eps=eps, verbose=verbose)
if c == 0:
sa_blocks.append(block(in_channels, out_channels))
elif k ==0:
sa_blocks.append(block(in_channels+embed_dim*has_temb, out_channels))
in_channels = out_channels
k += 1
extra_feature_channels = in_channels
if sa_configs is not None:
num_centers, radius, num_neighbors, out_channels = sa_configs
_out_channels = []
for oc in out_channels:
if isinstance(oc, (list, tuple)):
_out_channels.append([int(r * _oc) for _oc in oc])
else:
_out_channels.append(int(r * oc))
out_channels = _out_channels
if num_centers is None:
block = PointNetAModule
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim*has_temb if k==0 else 0 ),
out_channels=out_channels,
include_coordinates=True))
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
c += 1
if len(sa_blocks) == 1:
sa_layers.append(sa_blocks[0])
else:
sa_layers.append(nn.Sequential(*sa_blocks))
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
dropout=0.1, has_temb=1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1,
verbose=True):
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
c = 0
for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append(
PointNetFPModule(in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim*has_temb,
out_channels=out_channels)
)
in_channels = out_channels[-1]
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3,
resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, # with_se_relu=True,
normalize=normalize, eps=eps,
verbose=verbose)
fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels
if len(fp_blocks) == 1:
fp_layers.append(fp_blocks[0])
else:
fp_layers.append(nn.Sequential(*fp_blocks))
c += 1
return fp_layers, in_channels
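`LinearAttention` avoids the O(N^2) attention matrix by contracting keys with values first; the two `einsum` calls are equivalent to two batched matmuls, as this standalone check shows:

```
import torch

b, h, d, n = 2, 4, 32, 1024                        # batch, heads, head dim, points
q = torch.randn(b, h, d, n)
k = torch.randn(b, h, d, n).softmax(dim=-1)
v = torch.randn(b, h, d, n)

context = torch.einsum('bhdn,bhen->bhde', k, v)    # (b, h, d, d): cost O(n * d^2)
out = torch.einsum('bhde,bhdn->bhen', context, q)  # never materializes the n x n matrix

# same result via plain matmuls
assert torch.allclose(out, (k @ v.transpose(-1, -2)).transpose(-1, -2) @ q, atol=1e-4)
```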

models/pvcnn2_ada.py Normal file

@ -0,0 +1,568 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
copied and modified from source:
https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py
and functions under
https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules
"""
import copy
import functools
from loguru import logger
from einops import rearrange
import torch.nn as nn
import torch
import numpy as np
import third_party.pvcnn.functional as F
# from utils.checker import *
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
from .adagn import AdaGN
import os
quiet = int(os.environ.get('quiet', 0))
class SE3d(nn.Module):
def __init__(self, channel, reduction=8):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
self.channel = channel
def __repr__(self):
return f"SE({self.channel}, {self.channel})"
def forward(self, inputs):
return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)
class LinearAttention(nn.Module):
"""
copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159
"""
def __init__(self, dim, heads = 4, dim_head = 32, verbose=True):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
'''
Args:
x: torch.tensor (B,C,N), C=num-channels, N=num-points
Returns:
out: torch.tensor (B,C,N)
'''
x = x.unsqueeze(-1) # add w dimension
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
out = self.to_out(out)
out = out.squeeze(-1) # B,C,N,1 -> B,C,N
return out
def swish(input):
return input * torch.sigmoid(input)
class Swish(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return swish(input)
class BallQuery(nn.Module):
def __init__(self, radius, num_neighbors, include_coordinates=True):
super().__init__()
self.radius = radius
self.num_neighbors = num_neighbors
self.include_coordinates = include_coordinates
@custom_bwd
def backward(self, *args, **kwargs):
return super().backward(*args, **kwargs)
@custom_fwd(cast_inputs=torch.float32)
def forward(self, points_coords, centers_coords, points_features=None):
# input: BCN, BCN
# neighbor_features: B,D(+3),Ncenter
points_coords = points_coords.contiguous()
centers_coords = centers_coords.contiguous()
neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
if points_features is None:
assert self.include_coordinates, 'No Features For Grouping'
neighbor_features = neighbor_coordinates
else:
neighbor_features = F.grouping(points_features, neighbor_indices)
if self.include_coordinates:
neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
return neighbor_features
def extra_repr(self):
return 'radius={}, num_neighbors={}{}'.format(
self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
class SharedMLP(nn.Module):
def __init__(self, in_channels, out_channels, dim=1, cfg={}):
assert(len(cfg) > 0), cfg
super().__init__()
if dim==1:
conv = nn.Conv1d
else:
conv = nn.Conv2d
bn = functools.partial(AdaGN, dim, cfg)
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
layers = []
for oc in out_channels:
layers.append(conv(in_channels, oc, 1))
layers.append(bn(oc))
layers.append(Swish())
in_channels = oc
self.layers = nn.ModuleList(layers)
def forward(self, *inputs):
if len(inputs) == 1 and len(inputs[0]) == 4:
# unpack when SharedMLP is the first layer and receives the full input tuple
inputs = inputs[0]
if len(inputs) == 1:
raise NotImplementedError
elif len(inputs) == 4:
x, _, _, style = inputs
for l in self.layers:
if isinstance(l, AdaGN):
x = l(x, style)
else:
x = l(x)
return (x, *inputs[1:])
elif len(inputs) == 2:
x, style = inputs
for l in self.layers:
if isinstance(l, AdaGN):
x = l(x, style)
else:
x = l(x)
return x
else:
raise NotImplementedError
class Voxelization(nn.Module):
def __init__(self, resolution, normalize=True, eps=0):
super().__init__()
self.r = int(resolution)
self.normalize = normalize
self.eps = eps
def forward(self, features, coords):
# features: B,D,N
# coords: B,3,N
coords = coords.detach()
norm_coords = coords - coords.mean(2, keepdim=True)
if self.normalize:
norm_coords = norm_coords / (norm_coords.norm(
dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 +
self.eps) + 0.5
else:
norm_coords = (norm_coords + 1) / 2.0
norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
vox_coords = torch.round(norm_coords).to(torch.int32)
if features is None:
return features, norm_coords
return F.avg_voxelize(features, vox_coords, self.r), norm_coords
def extra_repr(self):
return 'resolution={}{}'.format(
self.r,
', normalized eps = {}'.format(self.eps) if self.normalize else '')
class PVConv(nn.Module):
def __init__(self, in_channels, out_channels,
kernel_size, resolution,
normalize=1, eps=0, with_se=False,
add_point_feat=True, attention=False,
dropout=0.1, verbose=True,
cfg={}
):
super().__init__()
assert(len(cfg) > 0), cfg
self.resolution = resolution
self.voxelization = Voxelization(resolution,
normalize=normalize,
eps=eps)
# For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention)
NormLayer = functools.partial(AdaGN, 3, cfg)
voxel_layers = [
nn.Conv3d(in_channels ,
out_channels,
kernel_size, stride=1,
padding=kernel_size // 2),
NormLayer(out_channels),
Swish(),
nn.Dropout(dropout),
nn.Conv3d(out_channels, out_channels,
kernel_size, stride=1,
padding=kernel_size // 2),
NormLayer(out_channels)
]
if with_se:
voxel_layers.append(SE3d(out_channels))
self.voxel_layers = nn.ModuleList(voxel_layers)
if attention:
self.attn = LinearAttention(out_channels, verbose=verbose)
else:
self.attn = None
if add_point_feat:
self.point_features = SharedMLP(in_channels, out_channels, cfg=cfg)
self.add_point_feat = add_point_feat
def forward(self, inputs):
'''
Args:
inputs: tuple of features and coords
features: B,feat-dim,num-points
coords: B,3, num-points
time_emb: B,D; time embedding
style: B,D; global latent
Returns:
fused_features: in (B,out-feat-dim,num-points)
coords : in (B, 3 or 6, num_points); same as the input coords
'''
features = inputs[0]
coords_input= inputs[1]
time_emb = inputs[2]
style = inputs[3]
if coords_input.shape[1] > 3:
coords = coords_input[:,:3]
else:
coords = coords_input
assert (features.shape[0] == coords.shape[0]
), f'get feat: {features.shape} and {coords.shape}'
assert (features.shape[2] == coords.shape[2]
), f'get feat: {features.shape} and {coords.shape}'
assert (coords.shape[1] == 3
), f'expect coords: B,3,Npoint, get: {coords.shape}'
# features: B,D,N; point_features
# coords: B,3,N
voxel_features_4d, voxel_coords = self.voxelization(features, coords)
r = self.resolution
B = coords.shape[0]
for voxel_layer in self.voxel_layers:
if isinstance(voxel_layer, AdaGN):
voxel_features_4d = voxel_layer(voxel_features_4d, style)
else:
voxel_features_4d = voxel_layer(voxel_features_4d)
voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords,
r, self.training)
fused_features = voxel_features
if self.add_point_feat:
fused_features = fused_features + self.point_features(features, style)
if self.attn is not None:
fused_features = self.attn(fused_features)
return fused_features, coords_input, time_emb, style
class PointNetAModule(nn.Module):
def __init__(self, in_channels, out_channels, include_coordinates=True, cfg={}):
super().__init__()
if not isinstance(out_channels, (list, tuple)):
out_channels = [[out_channels]]
elif not isinstance(out_channels[0], (list, tuple)):
out_channels = [out_channels]
mlps = []
total_out_channels = 0
for _out_channels in out_channels:
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=1, cfg=cfg)
)
total_out_channels += _out_channels[-1]
self.include_coordinates = include_coordinates
self.out_channels = total_out_channels
self.mlps = nn.ModuleList(mlps)
def forward(self, inputs):
features, coords, time_emb, style = inputs
if self.include_coordinates:
features = torch.cat([features, coords], dim=1)
coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
if len(self.mlps) > 1:
features_list = []
for mlp in self.mlps:
features_list.append(mlp(features, style).max(dim=-1, keepdim=True).values)
return torch.cat(features_list, dim=1), coords, time_emb, style
else:
return self.mlps[0](features, style).max(dim=-1, keepdim=True).values, coords, time_emb, style
def extra_repr(self):
return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
class PointNetSAModule(nn.Module):
def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True,
cfg={}):
super().__init__()
if not isinstance(radius, (list, tuple)):
radius = [radius]
if not isinstance(num_neighbors, (list, tuple)):
num_neighbors = [num_neighbors] * len(radius)
assert len(radius) == len(num_neighbors)
if not isinstance(out_channels, (list, tuple)):
out_channels = [[out_channels]] * len(radius)
elif not isinstance(out_channels[0], (list, tuple)):
out_channels = [out_channels] * len(radius)
assert len(radius) == len(out_channels)
groupers, mlps = [], []
total_out_channels = 0
for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
groupers.append(
BallQuery(radius=_radius, num_neighbors=_num_neighbors,
include_coordinates=include_coordinates)
)
mlps.append(
SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
out_channels=_out_channels, dim=2, cfg=cfg)
)
total_out_channels += _out_channels[-1]
self.num_centers = num_centers
self.out_channels = total_out_channels
self.groupers = nn.ModuleList(groupers)
self.mlps = nn.ModuleList(mlps)
def forward(self, inputs):
features = inputs[0]
coords = inputs[1] # B3N
style = inputs[3]
if coords.shape[1] > 3:
coords = coords[:,:3]
centers_coords = F.furthest_point_sample(coords, self.num_centers)
# centers_coords: (B, 3, num_centers)
S = centers_coords.shape[-1]
time_emb = inputs[2]
if time_emb is not None and not isinstance(time_emb, dict):
time_emb = time_emb[:, :, :S]
features_list = []
for grouper, mlp in zip(self.groupers, self.mlps):
grouper_output = grouper(coords, centers_coords, features)
features_list.append(mlp(grouper_output, style).max(dim=-1).values)
if len(features_list) > 1:
return torch.cat(features_list, dim=1), centers_coords, time_emb, style
else:
return features_list[0], centers_coords, time_emb, style
def extra_repr(self):
return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
class PointNetFPModule(nn.Module):
def __init__(self, in_channels, out_channels, cfg={}):
super().__init__()
self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1, cfg=cfg)
def forward(self, inputs):
if len(inputs) == 5:
points_coords, centers_coords, centers_features, time_emb, style = inputs
points_features = None
elif len(inputs) == 6:
points_coords, centers_coords, centers_features, points_features, time_emb, style = inputs
else:
raise NotImplementedError
interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
if points_features is not None:
interpolated_features = torch.cat(
[interpolated_features, points_features], dim=1
)
if time_emb is not None:
B,D,S = time_emb.shape
N = points_coords.shape[-1]
time_emb = time_emb[:,:,0:1].expand(-1,-1,N)
return self.mlp(interpolated_features, style), points_coords, time_emb, style
def _linear_gn_relu(in_channels, out_channels):
return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1, cfg={}):
r = width_multiplier
if dim == 1:
block = _linear_gn_relu
else:
block = SharedMLP
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
return nn.Sequential(), in_channels, in_channels
layers = []
for oc in out_channels[:-1]:
if oc < 1:
layers.append(nn.Dropout(oc))
else:
oc = int(r * oc)
layers.append(block(in_channels, oc, cfg=cfg))
in_channels = oc
if dim == 1:
if classifier:
layers.append(nn.Linear(in_channels, out_channels[-1]))
else:
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
else:
if classifier:
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
else:
layers.append(SharedMLP(in_channels, int(r * out_channels[-1]), cfg=cfg)) # this SharedMLP requires cfg
return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
def create_pointnet2_sa_components(sa_blocks, extra_feature_channels,
input_dim=3,
embed_dim=64, use_att=False, force_att=0,
dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1,
width_multiplier=1, voxel_resolution_multiplier=1, verbose=True,
cfg={}):
"""
Returns:
in_channels: the last output channels of the sa blocks
"""
assert(len(cfg) > 0), cfg
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + input_dim
sa_layers, sa_in_channels = [], []
c = 0
num_centers = None
for conv_configs, sa_configs in sa_blocks:
k = 0
sa_in_channels.append(in_channels)
sa_blocks = []
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = ( (c+1) % 2 == 0 and use_att and p == 0 ) or (force_att and c > 0)
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(
PVConv, kernel_size=3,
resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, # with_se_relu=True,
normalize=normalize, eps=eps, verbose=verbose, cfg=cfg)
if c == 0:
sa_blocks.append(block(in_channels, out_channels, cfg=cfg))
elif k ==0:
sa_blocks.append(block(in_channels+embed_dim*has_temb, out_channels, cfg=cfg))
in_channels = out_channels
k += 1
extra_feature_channels = in_channels
if sa_configs is not None:
num_centers, radius, num_neighbors, out_channels = sa_configs
_out_channels = []
for oc in out_channels:
if isinstance(oc, (list, tuple)):
_out_channels.append([int(r * _oc) for _oc in oc])
else:
_out_channels.append(int(r * oc))
out_channels = _out_channels
if num_centers is None:
block = PointNetAModule
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
sa_blocks.append(block(cfg=cfg,
in_channels=extra_feature_channels+(embed_dim*has_temb if k==0 else 0 ),
out_channels=out_channels,
include_coordinates=True))
in_channels = extra_feature_channels = sa_blocks[-1].out_channels
c += 1
if len(sa_blocks) == 1:
sa_layers.append(sa_blocks[0])
else:
sa_layers.append(nn.Sequential(*sa_blocks))
return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
dropout=0.1, has_temb=1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1,
verbose=True, cfg={}):
assert(len(cfg) > 0), cfg
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
c = 0
for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
fp_blocks = []
out_channels = tuple(int(r * oc) for oc in fp_configs)
fp_blocks.append(
PointNetFPModule(
in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim*has_temb,
out_channels=out_channels,
cfg=cfg)
)
in_channels = out_channels[-1]
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
if voxel_resolution is None:
block = functools.partial(SharedMLP, cfg=cfg)
else:
block = functools.partial(PVConv, kernel_size=3,
resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, # with_se_relu=True,
normalize=normalize, eps=eps,
verbose=verbose,
cfg=cfg)
fp_blocks.append(block(in_channels, out_channels))
in_channels = out_channels
if len(fp_blocks) == 1:
fp_layers.append(fp_blocks[0])
else:
fp_layers.append(nn.Sequential(*fp_blocks))
c += 1
return fp_layers, in_channels
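The `AdaGN` layer imported at the top of this file is not part of this diff; conceptually it is a GroupNorm whose affine parameters are predicted from the style latent. A minimal sketch under that assumption (names and details are illustrative, not the repo's exact module):

```
import torch
import torch.nn as nn

class AdaGNSketch(nn.Module):
    """GroupNorm modulated by a style vector: GN(x) * (1 + scale(s)) + shift(s)."""
    def __init__(self, num_channels, style_dim, num_groups=8):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
        self.affine = nn.Linear(style_dim, 2 * num_channels)  # predicts scale and shift

    def forward(self, x, style):                  # x: (B, C, ...), style: (B, style_dim)
        scale, shift = self.affine(style).chunk(2, dim=1)
        shape = (x.shape[0], x.shape[1]) + (1,) * (x.dim() - 2)
        return self.norm(x) * (1 + scale.view(shape)) + shift.view(shape)

x, s = torch.randn(2, 64, 2048), torch.randn(2, 128)
assert AdaGNSketch(64, 128)(x, s).shape == x.shape
```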

models/score_sde/resnet.py Normal file

@ -0,0 +1,230 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" implement the gloabl prior for LION
"""
import torch.nn as nn
from loguru import logger
import functools
import torch
from ..utils import init_temb_fun, mask_inactive_variables
class SE(nn.Module):
def __init__(self, channel, reduction=8):
super().__init__()
self.fc = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, 1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, 1, bias=False),
nn.Sigmoid()
)
def forward(self, inputs):
return inputs * self.fc(inputs)
class ResBlockSEClip(nn.Module):
"""
fixed the conv0 not used error in ResBlockSE
"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.non_linearity = nn.ReLU(inplace=True)
self.input_dim = input_dim
self.output_dim = output_dim
self.conv1 = nn.Conv2d(input_dim*2, output_dim, 1, 1)
self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
in_ch = self.output_dim
self.SE = SE(in_ch)
def forward(self, x, t):
## logger.info('x: {}, t: {}, input_dim={}', x.shape, t.shape, self.input_dim)
clip_feat = t[:, self.input_dim:].contiguous()
t = t[:,:self.input_dim].contiguous()
output = x + t
output = torch.cat([output, clip_feat], dim=1).contiguous()
output = self.conv1(output)
output = self.non_linearity(output)
output = self.conv2(output)
output = self.non_linearity(output)
output = self.SE(output)
shortcut = x
return shortcut + output
def __repr__(self):
return "ResBlockSEClip(%d, %d)"%(self.input_dim, self.output_dim)
class ResBlockSEDrop(nn.Module):
"""
fixed the conv0 not used error in ResBlockSE
"""
def __init__(self, input_dim, output_dim, dropout):
super().__init__()
self.non_linearity = nn.ReLU(inplace=True)
self.input_dim = input_dim
self.output_dim = output_dim
self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
in_ch = self.output_dim
self.SE = SE(in_ch)
self.dropout = nn.Dropout(dropout)
self.dropout_ratio = dropout
def forward(self, x, t):
output = x + t
output = self.conv1(output)
output = self.non_linearity(output)
output = self.dropout(output)
output = self.conv2(output)
output = self.non_linearity(output)
output = self.SE(output)
shortcut = x
return shortcut + output
def __repr__(self):
return "ResBlockSE_withdropout(%d, %d, drop=%f)" % (
self.input_dim, self.output_dim, self.dropout_ratio)
class ResBlock(nn.Module):
def __init__(self, input_dim, output_dim):
# resample=None, act=nn.ELU(),
# normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
super().__init__()
self.non_linearity = nn.ELU()
self.input_dim = input_dim
self.output_dim = output_dim
self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
in_ch = self.output_dim
self.normalize1 = nn.GroupNorm(num_groups=min(in_ch // 4, 32),
num_channels=in_ch, eps=1e-6)
self.normalize2 = nn.GroupNorm(num_groups=min(in_ch // 4, 32),
num_channels=in_ch, eps=1e-6)
def forward(self, x, t):
x = x + t
output = self.conv1(x)
output = self.normalize1(output)
output = self.non_linearity(output)
output = self.conv2(output)
output = self.normalize2(output)
output = self.non_linearity(output)
shortcut = x
return shortcut + output
def __repr__(self):
return "ResBlock(%d, %d)" % (self.input_dim, self.output_dim)
class Prior(nn.Module):
building_block = ResBlock
def __init__(self, args, num_input_channels, *oargs, **kwargs):
super().__init__()
# args: cfg.sde
# oargs: other argument: the global argument
self.condition_input = kwargs.get('condition_input', False)
self.cfg = oargs[0]
self.clip_forge_enable = self.cfg.clipforge.enable # kwargs.get('clipforge.enable', 0)
logger.info('[Build Resnet Prior] Has condition input: {}; clipforge {}; '
'learn_mixing_logit={}, ', self.condition_input,
self.clip_forge_enable, args.learn_mixing_logit)
self.act = act = nn.SiLU()
self.num_scales = args.num_scales_dae
self.num_input_channels = num_input_channels
self.nf = nf = args.num_channels_dae
num_cell_per_scale_dae = kwargs.get('num_cell_per_scale_dae', args.num_cell_per_scale_dae)
# take clip feature as input
if self.clip_forge_enable:
self.clip_feat_mapping = nn.Conv1d(self.cfg.clipforge.feat_dim, self.nf, 1)
# mixed_prediction #
self.mixed_prediction = args.mixed_prediction # This enables mixed prediction
if self.mixed_prediction:
logger.info('init-mixing_logit = {}, after sigmoid = {}',
args.mixing_logit_init, torch.sigmoid(torch.tensor(args.mixing_logit_init)))
assert args.mixing_logit_init, 'require learning'
init = args.mixing_logit_init * torch.ones(size=[1, num_input_channels, 1, 1])
self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
self.is_active = None
else: # no mixing_logit
self.mixing_logit = None
self.is_active = None
self.embedding_dim = args.embedding_dim
self.embedding_dim_mult = 4
self.temb_fun = init_temb_fun(args.embedding_type, args.embedding_scale, args.embedding_dim)
logger.info('[temb_fun] embedding_type={}, embedding_scale={}, embedding_dim={}',
args.embedding_type, args.embedding_scale, args.embedding_dim)
modules = []
modules.append(nn.Conv2d(self.embedding_dim, self.embedding_dim * 4, 1, 1))
modules.append(nn.Conv2d(self.embedding_dim * 4, nf, 1, 1))
self.temb_layer = nn.Sequential(*modules)
modules = []
input_channels = num_input_channels
self.input_layer = nn.Conv2d(input_channels, nf, 1, 1)
in_ch = nf
for i_block in range(num_cell_per_scale_dae):
modules.append(self.building_block(nf, nf))
self.output_layer = nn.Conv2d(nf, input_channels, 1, 1)
self.all_modules = nn.ModuleList(modules)
def forward(self, x, t, **kwargs):
# timestep / noise-level embedding (continuous training)
if t.dim() == 0:
t = t.expand(1)
temb = self.temb_fun(t)[:, :, None, None] # make it 4d
temb = self.temb_layer(temb)
if self.clip_forge_enable:
clip_feat = kwargs['clip_feat']
clip_feat = self.clip_feat_mapping(clip_feat[:, :, None])[:, :, :, None] # B,D -> BD1->B,D,1,1
if temb.shape[0] == 1 and temb.shape[0] < clip_feat.shape[0]:
temb = temb.expand(clip_feat.shape[0], -1, -1, -1)
temb = torch.cat([temb, clip_feat], dim=1) # add to temb feature
# mask out inactive variables
if self.mixed_prediction and self.is_active is not None:
x = mask_inactive_variables(x, self.is_active)
x = self.input_layer(x)
for layer in self.all_modules:
enc_input = x
x = layer(enc_input, temb)
h = self.output_layer(x)
return h
class PriorSEDrop(Prior):
def __init__(self, *args, **kwargs):
self.building_block = functools.partial(ResBlockSEDrop, dropout=args[0].dropout)
super().__init__(*args, **kwargs)
class PriorSEClip(Prior):
building_block = ResBlockSEClip
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
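The global prior treats the shape latent as a 1x1 "image": every block operates on `(B, C, 1, 1)` tensors with 1x1 convolutions. A quick shape check (assumes this repo is on `PYTHONPATH`):

```
import torch
from models.score_sde.resnet import ResBlock

blk = ResBlock(128, 128)
x = torch.randn(4, 128, 1, 1)         # global shape latent
temb = torch.randn(4, 128, 1, 1)      # time embedding mapped to the same width
assert blk(x, temb).shape == x.shape  # residual blocks preserve the shape
```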


@ -0,0 +1,54 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch.nn as nn
from loguru import logger
from .pvcnn2 import create_pointnet2_sa_components
# implement the global encoder for the VAE model
class PointNetPlusEncoder(nn.Module):
sa_blocks = [
[[32, 2, 32], [1024, 0.1, 32, [32, 32]]],
[[32, 1, 16], [256, 0.2, 32, [32, 64]]]
]
force_att = 0 # if set, add attention to all sa layers
def __init__(self, zdim, input_dim, extra_feature_channels=0, args={}):
super().__init__()
sa_blocks = self.sa_blocks
layers, sa_in_channels, channels_sa_features, _ = \
create_pointnet2_sa_components(sa_blocks,
extra_feature_channels, input_dim=input_dim,
embed_dim=0, force_att=self.force_att,
use_att=True, with_se=True)
self.mlp = nn.Linear(channels_sa_features, zdim*2)
self.zdim = zdim
logger.info('[Encoder] zdim={}, out_sigma={}; force_att: {}', zdim, True, self.force_att)
self.layers = nn.ModuleList(layers)
self.voxel_dim = [n[1][-1][-1] for n in self.sa_blocks]
def forward(self, x):
"""
Args:
x: B,N,3
Returns:
mu, sigma: B,D
"""
output = {}
x = x.transpose(1, 2) # B,3,N
xyz = x ## x[:,:3,:]
features = x
for layer_id, layer in enumerate(self.layers):
features, xyz, _ = layer( (features, xyz, None) )
# features: B,D,N; xyz: B,3,N
features = features.max(-1)[0]
features = self.mlp(features)
mu_1d, sigma_1d = features[:, :self.zdim], features[:, self.zdim:]
output.update({'mu_1d': mu_1d, 'sigma_1d': sigma_1d})
return output
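# Usage sketch (hypothetical shapes; assumes the default args suffice):
#   enc = PointNetPlusEncoder(zdim=128, input_dim=3)
#   out = enc(torch.randn(8, 2048, 3))
#   out['mu_1d'], out['sigma_1d']  # each (8, 128)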

52
models/utils.py Normal file
View file

@@ -0,0 +1,52 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import math
import torch.nn as nn
def mask_inactive_variables(x, is_active):
x = x * is_active
return x
class PositionalEmbedding(nn.Module):
def __init__(self, embedding_dim, scale):
super(PositionalEmbedding, self).__init__()
self.embedding_dim = embedding_dim
self.scale = scale
def forward(self, timesteps):
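# Transformer-style sinusoidal embedding: frequencies omega_k = 10000^(-k / (half_dim - 1)),
# output is [sin(t * omega_k), cos(t * omega_k)] concatenated along the feature dim.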
assert len(timesteps.shape) == 1
timesteps = timesteps * self.scale
half_dim = self.embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
return emb
class RandomFourierEmbedding(nn.Module):
def __init__(self, embedding_dim, scale):
super(RandomFourierEmbedding, self).__init__()
self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False)
def forward(self, timesteps):
emb = torch.mm(timesteps[:, None], self.w * 2 * math.pi)
return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
def init_temb_fun(embedding_type, embedding_scale, embedding_dim):
if embedding_type == 'positional':
temb_fun = PositionalEmbedding(embedding_dim, embedding_scale)
elif embedding_type == 'fourier':
temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale)
else:
raise NotImplementedError
return temb_fun
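# Usage sketch (hypothetical values; embedding_dim must be even):
#   temb_fun = init_temb_fun('positional', embedding_scale=1.0, embedding_dim=128)
#   temb = temb_fun(torch.rand(16))  # (16, 128)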

339
models/vae_adain.py Normal file
View file

@@ -0,0 +1,339 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import numpy as np
from loguru import logger
import importlib
import torch.nn as nn
from .distributions import Normal
from utils.model_helper import import_model
from utils.model_helper import loss_fn
from utils import utils as helper
class Model(nn.Module):
def __init__(self, args):
super().__init__()
self.num_total_iter = 0
self.args = args
self.input_dim = args.ddpm.input_dim
latent_dim = args.shapelatent.latent_dim
self.latent_dim = latent_dim
self.kl_weight = args.shapelatent.kl_weight
self.num_points = args.data.tr_max_sample_points
# ---- global ---- #
# build encoder
self.style_encoder = import_model(args.latent_pts.style_encoder)(
zdim=args.latent_pts.style_dim,
input_dim=self.input_dim,
args=args)
if len(args.latent_pts.style_mlp):
self.style_mlp = import_model(args.latent_pts.style_mlp)(args)
else:
self.style_mlp = None
self.encoder = import_model(args.shapelatent.encoder_type)(
zdim=latent_dim,
input_dim=self.input_dim,
args=args)
# build decoder
self.decoder = import_model(args.shapelatent.decoder_type)(
context_dim=latent_dim,
point_dim=args.ddpm.input_dim,
args=args)
logger.info('[Build Model] style_encoder: {}, encoder: {}, decoder: {}',
args.latent_pts.style_encoder,
args.shapelatent.encoder_type,
args.shapelatent.decoder_type)
@torch.no_grad()
def encode(self, x, class_label=None):
batch_size, _, point_dim = x.size()
assert(x.shape[2] == self.input_dim), f'expect input in ' \
f'[B,Npoint,PointDim={self.input_dim}], got: {x.shape}'
x_0_target = x
latent_list = []
all_eps = []
all_log_q = []
if self.args.data.cond_on_cat:
assert(class_label is not None), f'require class label input for cond on cat'
cls_emb = self.class_embedding(class_label)
enc_input = x, cls_emb
else:
enc_input = x
# ---- global style encoder ---- #
z = self.style_encoder(enc_input)
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
z_global = dist.sample()[0]
all_eps.append(z_global)
all_log_q.append(dist.log_p(z_global))
latent_list.append( [z_global, z_mu, z_sigma] )
# ---- original encoder ---- #
style = z_global # torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
style = self.style_mlp(style) if self.style_mlp is not None else style
z = self.encoder([x, style])
z_mu, z_sigma = z['mu_1d'], z['sigma_1d']
z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
z_local = dist.sample()[0]
all_eps.append(z_local)
all_log_q.append(dist.log_p(z_local))
latent_list.append( [z_local, z_mu, z_sigma] )
all_eps = self.compose_eps(all_eps)
if self.args.data.cond_on_cat:
return all_eps, all_log_q, latent_list, cls_emb
else:
return all_eps, all_log_q, latent_list
def compose_eps(self, all_eps):
return torch.cat(all_eps, dim=1) # style: [B,D1], latent pts: [B,ND2]
def decompose_eps(self, all_eps):
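# all_eps is the concatenation [global style (B, style_dim) | latent points
# (B, N*(latent_dim+input_dim))]; split it back (inverse of compose_eps).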
eps_style = all_eps[:,:self.args.latent_pts.style_dim]
eps_local = all_eps[:,self.args.latent_pts.style_dim:]
return [eps_style, eps_local]
def encode_global(self, x, class_label=None):
batch_size, N, point_dim = x.size()
if self.args.data.cond_on_cat:
assert(class_label is not None), f'require class label input for cond on cat'
cls_emb = self.class_embedding(class_label)
enc_input = x, cls_emb
else:
enc_input = x
z = self.style_encoder(enc_input)
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
return dist
def global2style(self, style): ##, cls_emb=None):
Ndim = len(style.shape)
if Ndim == 4:
style = style.squeeze(-1).squeeze(-1)
style = self.style_mlp(style) if self.style_mlp is not None else style
if Ndim == 4:
style = style.unsqueeze(-1).unsqueeze(-1)
return style
def encode_local(self, x, style):
# ---- original encoder ---- #
z = self.encoder([x, style])
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
return dist
def recont(self, x, target=None, class_label=None, cls_emb=None):
batch_size, N, point_dim = x.size()
assert(x.shape[2] == self.input_dim), f'expect input in ' \
f'[B,Npoint,PointDim={self.input_dim}], got: {x.shape}'
x_0_target = x if target is None else target
latent_list = []
all_eps = []
all_log_q = []
# ---- global style encoder ---- #
if self.args.data.cond_on_cat:
if class_label is not None:
assert(class_label is not None)
cls_emb = self.class_embedding(class_label)
else:
assert(cls_emb is not None)
enc_input = x, cls_emb
else:
enc_input = x
z = self.style_encoder(enc_input)
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
z_global = dist.sample()[0]
all_eps.append(z_global)
all_log_q.append(dist.log_p(z_global))
latent_list.append( [z_global, z_mu, z_sigma] )
# ---- original encoder ---- #
style = torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
style = self.style_mlp(style) if self.style_mlp is not None else style
z = self.encoder([x, style])
z_mu, z_sigma = z['mu_1d'], z['sigma_1d'] # log_sigma
z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
dist = Normal(mu=z_mu, log_sigma=z_sigma) # (B, F)
z_local = dist.sample()[0]
all_eps.append(z_local)
all_log_q.append(dist.log_p(z_local))
latent_list.append( [z_local, z_mu, z_sigma] )
# ---- decoder ---- #
x_0_pred = self.decoder(None, beta=None, context=z_local, style=style) # (B,ncenter,3)
make_4d = lambda x: x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2 else x.unsqueeze(-1)
all_eps = [make_4d(e) for e in all_eps]
all_log_q = [make_4d(e) for e in all_log_q]
output = {
'all_eps': all_eps,
'all_log_q': all_log_q,
'latent_list': latent_list,
'x_0_pred':x_0_pred,
'x_0_target': x_0_target,
'x_t': torch.zeros_like(x_0_target),
't': torch.zeros(batch_size),
'x_0': x_0_target
}
output['hist/global_var'] = latent_list[0][2].exp()
if 'LatentPoint' in self.args.shapelatent.decoder_type:
latent_shape = [batch_size, -1, self.latent_dim + self.input_dim]
if 'Hir' in self.args.shapelatent.decoder_type:
latent_pts = z_local[:,:-self.args.latent_pts.latent_dim_ext[0]].view(*latent_shape)[:,:,:3].contiguous().clone()
else:
latent_pts = z_local.view(*latent_shape)[:,:,:self.input_dim].contiguous().clone()
output['vis/latent_pts'] = latent_pts.detach().cpu().view(batch_size,
-1, self.input_dim) # B,N,3
output['final_pred'] = output['x_0_pred']
return output
def get_loss(self, x, writer=None, it=None, ## weight_loss_1=1,
noisy_input=None, class_label=None, **kwargs):
"""
shapelatent z ~ q(z|x_0)
and x_t ~ q(x_t|x_0, t), t ~ Uniform(T)
forward and get x_{t-1} ~ p(x_{t-1} | x_t, z)
Args:
x: Input point clouds, (B, N, d).
"""
## kl_weight = self.kl_weight
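# KL weight schedule: held at kl_const_coeff_vada for the first kl_const_portion_vada
# of training, then annealed up to kl_max_coeff_vada over kl_anneal_portion_vada.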
if self.args.trainer.anneal_kl and self.num_total_iter > 0:
global_step = it
kl_weight = helper.kl_coeff(step=global_step,
total_step=self.args.sde.kl_anneal_portion_vada * self.num_total_iter,
constant_step=self.args.sde.kl_const_portion_vada * self.num_total_iter,
min_kl_coeff=self.args.sde.kl_const_coeff_vada,
max_kl_coeff=self.args.sde.kl_max_coeff_vada)
else:
kl_weight = self.kl_weight
batch_size = x.shape[0]
# CHECKDIM(x, 2, self.input_dim)
assert(x.shape[2] == self.input_dim)
inputs = noisy_input if noisy_input is not None else x
output = self.recont(inputs, target=x, class_label=class_label)
x_0_pred, x_0_target = output['x_0_pred'], output['x_0_target']
loss_0 = loss_fn(x_0_pred, x_0_target, self.args.ddpm.loss_type,
self.input_dim, batch_size).mean()
rec_loss = loss_0
output['print/loss_0'] = loss_0
output['rec_loss'] = rec_loss
# Loss
## z_global, z_sigma, z_mu = output['z_global'], output['z_sigma'], output['z_mu']
kl_term_list = []
weighted_kl_terms = []
for pairs_id, pairs in enumerate(output['latent_list']):
cz, cmu, csigma = pairs
log_sigma = csigma
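# closed-form KL(N(mu, sigma^2) || N(0, I)) per dimension:
# 0.5*sigma^2 + 0.5*mu^2 - log(sigma) - 0.5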
kl_term_close = (0.5*log_sigma.exp()**2 +
0.5*cmu**2 - log_sigma - 0.5).view(
batch_size, -1)
if 'LatentPoint' in self.args.shapelatent.decoder_type and 'Hir' not in self.args.shapelatent.decoder_type:
if pairs_id == 1:
latent_shape = [batch_size, -1, self.latent_dim + self.input_dim]
kl_pt = kl_term_close.view(*latent_shape)[:,:,:self.input_dim]
kl_feat = kl_term_close.view(*latent_shape)[:,:,self.input_dim:]
weighted_kl_terms.append(kl_pt.sum(2).sum(1) * self.args.latent_pts.weight_kl_pt)
weighted_kl_terms.append(kl_feat.sum(2).sum(1) * self.args.latent_pts.weight_kl_feat)
output['print/kl_pt%d'%pairs_id] = kl_pt.sum(2).sum(1)
output['print/kl_feat%d'%pairs_id] = kl_feat.sum(2).sum(1)
output['print/z_var_pt%d'%pairs_id] = (log_sigma.view(*latent_shape)[:,:,:self.input_dim]
).exp()**2
output['print/z_var_feat%d'%pairs_id] = (log_sigma.view(*latent_shape)[:,:,self.input_dim:]
).exp()**2
output['print/z_mean_feat%d'%pairs_id] = cmu.view(*latent_shape)[:,:,self.input_dim:].mean()
elif pairs_id == 0:
kl_style = kl_term_close
weighted_kl_terms.append(kl_style.sum(-1) * self.args.latent_pts.weight_kl_glb)
output['print/kl_glb%d'%pairs_id] = kl_style.sum(-1)
output['print/z_var_glb%d'%pairs_id] = (log_sigma).exp()**2
kl_term_close = kl_term_close.sum(-1)
kl_term_list.append(kl_term_close)
output['print/kl_%d'%pairs_id] = kl_term_close
output['print/z_mean_%d'%pairs_id] = cmu.mean()
output['print/z_mag_%d'%pairs_id] = cmu.abs().max()
# logger.info('log_sigma: {}, mean: {}', log_sigma.shape, (log_sigma.exp()**2).mean())
output['print/z_var_%d'%pairs_id] = (log_sigma).exp()**2
output['print/z_logsigma_%d'%pairs_id] = log_sigma
output['print/kl_weight'] = kl_weight
loss_recons = rec_loss
if len(weighted_kl_terms) > 0:
kl = kl_weight * sum(weighted_kl_terms)
else:
kl = kl_weight * sum(kl_term_list)
loss = kl + loss_recons * self.args.weight_recont
output['msg/kl'] = kl
output['msg/rec'] = loss_recons
output['loss'] = loss
return output
def pz(self, w):
return w
def sample(self, num_samples=10, temp=None, decomposed_eps=[],
enable_autocast=False, device_str='cuda', cls_emb=None):
""" currently not support the samples of local level
Return:
model_output: [B,N,D]
"""
batch_size = num_samples
center_emd = None
if 'LatentPoint' in self.args.shapelatent.decoder_type:
# Latent Point Model: latent shape; B; ND
latent_shape = (num_samples, self.num_points*(self.latent_dim+self.input_dim))
style_latent_shape = (num_samples, self.args.latent_pts.style_dim)
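# z_local flattens N latent points, each of dim (latent_dim + input_dim);
# z_global is the shape-level style vector of dim style_dim.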
else:
raise NotImplementedError
if len(decomposed_eps) == 0:
z_local = torch.zeros(*latent_shape).to(
torch.device(device_str)).normal_()
z_global = torch.zeros(*style_latent_shape).to(
torch.device(device_str)).normal_()
else:
z_global = decomposed_eps[0]
z_local = decomposed_eps[1]
z_local = z_local.view(*latent_shape)
z_global = z_global.view(style_latent_shape)
style = z_global
style = self.style_mlp(style) if self.style_mlp is not None else style
x_0_pred = self.decoder(None, beta=None,
context=z_local, style=style) # (B,ncenter,3); use the MLP-mapped style
## CHECKSIZE(x_0_pred, (batch_size,self.num_points,[3,6]))
return x_0_pred
def latent_shape(self):
return [
[self.args.latent_pts.style_dim, 1, 1],
[self.num_points*(self.latent_dim+self.input_dim),1,1]
]

43
script/compute_score.py Normal file
View file

@@ -0,0 +1,43 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import sys
sys.path.append('.')
from utils.eval_helper import compute_score
# samples = sys.argv[1]
# ref = sys.argv[2]
samples = './lion_ckpt/unconditional/car/samples.pt'
ref = './datasets/test_data/ref_val_car.pt'
compute_score(samples, ref_name=ref)
"""
will get:
[Test] MinMatDis | CD 0.000913 | EMD 0.007523
[Test] Coverage | CD 0.500000 | EMD 0.565341
[Test] 1NN-Accur | CD 0.534091 | EMD 0.511364
[Test] JsnShnDis | 0.009229
"""
samples = './lion_ckpt/unconditional/chair/samples.pt'
ref = './datasets/test_data/ref_val_chair.pt'
compute_score(samples, ref_name=ref)
"""
[Test] MinMatDis | CD 0.002643 | EMD 0.015516
[Test] Coverage | CD 0.489426 | EMD 0.521148
[Test] 1NN-Accur | CD 0.537009 | EMD 0.523414
[Test] JsnShnDis | 0.013535
"""
samples = './lion_ckpt/unconditional/airplane/samples.pt'
ref = './datasets/test_data/ref_val_airplane.pt'
compute_score(samples, ref_name=ref)
"""
[Test] MinMatDis | CD 0.000221 | EMD 0.003706
[Test] Coverage | CD 0.471605 | EMD 0.496296
[Test] 1NN-Accur | CD 0.674074 | EMD 0.612346
[Test] JsnShnDis | 0.060703
"""

View file

@@ -0,0 +1,3 @@
*__pycache__*
/tmp
tmp/*

View file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 ThibaultGROUEIX
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

104
third_party/ChamferDistancePytorch/README.md vendored Executable file
View file

@@ -0,0 +1,104 @@
* adapted from https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
----------------------------------
# Pytorch Chamfer Distance.
Includes a **CUDA** version and a **PYTHON** version with standard PyTorch operations.
NB: In this repo, dist1 and dist2 are squared point-cloud Euclidean distances, so you should adapt thresholds accordingly.
- [x] F - Score
### CUDA VERSION
- [x] JIT compilation
- [x] Supports multi-gpu
- [x] 2D point clouds.
- [x] 3D point clouds.
- [x] 5D point clouds.
- [x] Contiguous() safe.
### Python Version
- [x] Supports any dimension
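A minimal sketch of the pure-PyTorch fallback (assuming `chamfer_python.py` from this repo is on the import path; shapes are illustrative):

```python
import torch
from chamfer_python import distChamfer  # pure-PyTorch, any dimension, CPU or GPU

a = torch.rand(4, 100, 7)  # 7D points: beyond the 2/3/5/6D CUDA kernels
b = torch.rand(4, 120, 7)
dist1, dist2, idx1, idx2 = distChamfer(a, b)  # squared distances, like the CUDA version
```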
### Usage
```python
import torch, chamfer3D.dist_chamfer_3D, fscore
chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
points1 = torch.rand(32, 1000, 3).cuda()
points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda()
dist1, dist2, idx1, idx2 = chamLoss(points1, points2)
f_score, precision, recall = fscore.fscore(dist1, dist2)
```
### Add it to your project as a submodule
```shell
git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
```
### Benchmark: [forward + backward] pass
- [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4
- [x] p1 : 32 x 2000 x dim
- [x] p2 : 32 x 1000 x dim
| *Timing (sec * 1000)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | **1.2** | 1.4 |1.8 |
| **Cuda JIT** | 1.3 | **1.4** |**1.5** |
| **Python** | 37 | 37 | 37 |
| *Memory (MB)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | 529 | 529 | 549 |
| **Cuda JIT** | **520** | **529** |**549** |
| **Python** | 2495 | 2495 | 2495 |
### What is the Chamfer distance?
[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning
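For reference, the (squared) Chamfer distance between point sets $S_1$ and $S_2$ computed here is

$$
d_{CD}(S_1, S_2) = \sum_{x \in S_1} \min_{y \in S_2} \lVert x - y \rVert_2^2 \;+\; \sum_{y \in S_2} \min_{x \in S_1} \lVert x - y \rVert_2^2 ,
$$

where the kernels return the per-point minima (`dist1`, `dist2`) and leave the reduction (sum or mean) to the caller.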
### Acknowledgment
Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu).
JIT cool trick from [Christian Diller](https://github.com/chrdiller)
### Troubleshoot
- `Undefined symbol: Zxxxxxxxxxxxxxxxxx `:
--> Fix: Make sure to `import torch` before you `import chamfer`.
--> Use PyTorch >= 1.1.0.
- [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167)
```shell
wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
```
#### TODO:
* Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions

View file

@@ -0,0 +1,182 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
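// Tiled brute-force nearest-neighbor search: each outer iteration stages up to
// `batch` points of xyz2 in shared memory; every thread then scans its xyz points
// against the tile, tracking the best squared distance (inner loop unrolled 4x).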
const int batch=512;
__shared__ float buf[batch*2];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*2;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*2+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*2+0];
float y1=xyz[(i*n+j)*2+1];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&2);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*2+2]-x1;
float y2=buf[k*2+3]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*2+4]-x1;
float y2=buf[k*2+5]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*2+6]-x1;
float y2=buf[k*2+7]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*2+2]-x1;
float y2=buf[k*2+3]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*2+4]-x1;
float y2=buf[k*2+5]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*2+6]-x1;
float y2=buf[k*2+7]-y1;
float d=x2*x2+y2*y2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*2+0]-x1;
float y2=buf[k*2+1]-y1;
float d=x2*x2+y2*y2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*2+0];
float y1=xyz1[(i*n+j)*2+1];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*2+0];
float y2=xyz2[(i*m+j2)*2+1];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*2+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*2+1]),g*(y1-y2));
atomicAdd(&(grad_xyz2[(i*m+j2)*2+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*2+1]),-(g*(y1-y2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@@ -0,0 +1,80 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_2D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 2D")
cur_path = os.path.dirname(os.path.abspath(__file__))
build_path = cur_path.replace('chamfer2D', 'tmp')
os.makedirs(build_path, exist_ok=True)
from torch.utils.cpp_extension import load
chamfer_2D = load(name="chamfer_2D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]),
], build_directory=build_path)
print("Loaded JIT 2D CUDA chamfer distance")
else:
import chamfer_2D
print("Loaded compiled 2D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_2DFunction(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, dim = xyz1.size()
assert dim==2, "Wrong last dimension for the chamfer distance's input! Check with .size()"
_, m, dim = xyz2.size()
assert dim==2, "Wrong last dimension for the chamfer distance's input! Check with .size()"
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_2D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_2DDist(nn.Module):
def __init__(self):
super(chamfer_2DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_2DFunction.apply(input1, input2)

View file

@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_2D',
ext_modules=[
CUDAExtension('chamfer_2D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']),
]),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@@ -0,0 +1,196 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=512;
__shared__ float buf[batch*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*3+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*3+0];
float y1=xyz[(i*n+j)*3+1];
float z1=xyz[(i*n+j)*3+2];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&3);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*3+3]-x1;
float y2=buf[k*3+4]-y1;
float z2=buf[k*3+5]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*3+6]-x1;
float y2=buf[k*3+7]-y1;
float z2=buf[k*3+8]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*3+9]-x1;
float y2=buf[k*3+10]-y1;
float z2=buf[k*3+11]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*3+3]-x1;
float y2=buf[k*3+4]-y1;
float z2=buf[k*3+5]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*3+6]-x1;
float y2=buf[k*3+7]-y1;
float z2=buf[k*3+8]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*3+9]-x1;
float y2=buf[k*3+10]-y1;
float z2=buf[k*3+11]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*3+0]-x1;
float y2=buf[k*3+1]-y1;
float z2=buf[k*3+2]-z1;
float d=x2*x2+y2*y2+z2*z2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*3+0];
float y1=xyz1[(i*n+j)*3+1];
float z1=xyz1[(i*n+j)*3+2];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*3+0];
float y2=xyz2[(i*m+j2)*3+1];
float z2=xyz2[(i*m+j2)*3+2];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@@ -0,0 +1,133 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
cur_path = os.path.dirname(os.path.abspath(__file__))
build_path = cur_path.replace('chamfer3D', 'tmp')
os.makedirs(build_path, exist_ok=True)
from torch.utils.cpp_extension import load
chamfer_3D = load(name="chamfer_3D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
], build_directory=build_path)
#chamfer_found = importlib.find_loader("chamfer_3D") is not None
#if not chamfer_found:
# ## Cool trick from https://github.com/chrdiller
# print("Jitting Chamfer 3D")
# cur_path = os.path.dirname(os.path.abspath(__file__))
# build_path = cur_path.replace('chamfer3D', 'tmp')
# os.makedirs(build_path, exist_ok=True)
#
# from torch.utils.cpp_extension import load
# chamfer_3D = load(name="chamfer_3D",
# sources=[
# "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
# "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
# ], build_directory=build_path)
# print("Loaded JIT 3D CUDA chamfer distance")
#
#else:
# import chamfer_3D
# print("Loaded compiled 3D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, xyz1, xyz2):
batchsize, n, dim = xyz1.size()
assert dim==3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
_, m, dim = xyz2.size()
assert dim==3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
@custom_bwd
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_3D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_3DDist(nn.Module):
def __init__(self):
super(chamfer_3DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_3DFunction.apply(input1, input2)
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction_noGrad(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, dim = xyz1.size()
assert dim==3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
_, m, dim = xyz2.size()
assert dim==3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
return dist1, dist2, idx1, idx2
class chamfer_3DDist_nograd(nn.Module):
def __init__(self):
super(chamfer_3DDist_nograd, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_3DFunction_noGrad.apply(input1, input2)

View file

@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_3D',
ext_modules=[
CUDAExtension('chamfer_3D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
]),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@@ -0,0 +1,223 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=2048;
__shared__ float buf[batch*5];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*5;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*5+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*5+0];
float y1=xyz[(i*n+j)*5+1];
float r1=xyz[(i*n+j)*5+2];
float g1=xyz[(i*n+j)*5+3];
float b1=xyz[(i*n+j)*5+4];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&5);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*5+0]-x1;
float y2=buf[k*5+1]-y1;
float r2=buf[k*5+2]-r1;
float g2=buf[k*5+3]-g1;
float b2=buf[k*5+4]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*5+5]-x1;
float y2=buf[k*5+6]-y1;
float r2=buf[k*5+7]-r1;
float g2=buf[k*5+8]-g1;
float b2=buf[k*5+9]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*5+10]-x1;
float y2=buf[k*5+11]-y1;
float r2=buf[k*5+12]-r1;
float g2=buf[k*5+13]-g1;
float b2=buf[k*5+14]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*5+15]-x1;
float y2=buf[k*5+16]-y1;
float r2=buf[k*5+17]-r1;
float g2=buf[k*5+18]-g1;
float b2=buf[k*5+19]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*5+0]-x1;
float y2=buf[k*5+1]-y1;
float r2=buf[k*5+2]-r1;
float g2=buf[k*5+3]-g1;
float b2=buf[k*5+4]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*5+5]-x1;
float y2=buf[k*5+6]-y1;
float r2=buf[k*5+7]-r1;
float g2=buf[k*5+8]-g1;
float b2=buf[k*5+9]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*5+10]-x1;
float y2=buf[k*5+11]-y1;
float r2=buf[k*5+12]-r1;
float g2=buf[k*5+13]-g1;
float b2=buf[k*5+14]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*5+15]-x1;
float y2=buf[k*5+16]-y1;
float r2=buf[k*5+17]-r1;
float g2=buf[k*5+18]-g1;
float b2=buf[k*5+19]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*5+0]-x1;
float y2=buf[k*5+1]-y1;
float r2=buf[k*5+2]-r1;
float g2=buf[k*5+3]-g1;
float b2=buf[k*5+4]-b1;
float d=x2*x2+y2*y2+r2*r2+g2*g2+b2*b2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*5+0];
float y1=xyz1[(i*n+j)*5+1];
float r1=xyz1[(i*n+j)*5+2];
float g1=xyz1[(i*n+j)*5+3];
float b1=xyz1[(i*n+j)*5+4];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*5+0];
float y2=xyz2[(i*m+j2)*5+1];
float r2=xyz2[(i*m+j2)*5+2];
float g2=xyz2[(i*m+j2)*5+3];
float b2=xyz2[(i*m+j2)*5+4];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*5+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+1]),g*(y1-y2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+2]),g*(r1-r2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+3]),g*(g1-g2));
atomicAdd(&(grad_xyz1[(i*n+j)*5+4]),g*(b1-b2));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+1]),-(g*(y1-y2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+2]),-(g*(r1-r2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+3]),-(g*(g1-g2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*5+4]),-(g*(b1-b2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@@ -0,0 +1,82 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_5D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 5D")
cur_path = os.path.dirname(os.path.abspath(__file__))
build_path = cur_path.replace('chamfer5D', 'tmp')
os.makedirs(build_path, exist_ok=True)
from torch.utils.cpp_extension import load
chamfer_5D = load(name="chamfer_5D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]),
], build_directory=build_path)
print("Loaded JIT 5D CUDA chamfer distance")
else:
import chamfer_5D
print("Loaded compiled 5D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_5DFunction(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, dim = xyz1.size()
assert dim==5, "Wrong last dimension for the chamfer distance's input! Check with .size()"
_, m, dim = xyz2.size()
assert dim==5, "Wrong last dimension for the chamfer distance's input! Check with .size()"
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_5D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_5DDist(nn.Module):
def __init__(self):
super(chamfer_5DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_5DFunction.apply(input1, input2)

View file

@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_5D',
ext_modules=[
CUDAExtension('chamfer_5D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']),
]),
],
cmdclass={
'build_ext': BuildExtension
})

View file

@@ -0,0 +1,237 @@
#include <stdio.h>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=2048;
__shared__ float buf[batch*6];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int k2=0;k2<m;k2+=batch){
int end_k=min(m,k2+batch)-k2;
for (int j=threadIdx.x;j<end_k*6;j+=blockDim.x){
buf[j]=xyz2[(i*m+k2)*6+j];
}
__syncthreads();
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz[(i*n+j)*6+0];
float y1=xyz[(i*n+j)*6+1];
float z1=xyz[(i*n+j)*6+2];
float nx1=xyz[(i*n+j)*6+3];
float ny1=xyz[(i*n+j)*6+4];
float nz1=xyz[(i*n+j)*6+5];
int best_i=0;
float best=0;
int end_ka=end_k-(end_k&6);
if (end_ka==batch){
for (int k=0;k<batch;k+=4){
{
float x2=buf[k*6+0]-x1;
float y2=buf[k*6+1]-y1;
float z2=buf[k*6+2]-z1;
float nx2=buf[k*6+3]-nx1;
float ny2=buf[k*6+4]-ny1;
float nz2=buf[k*6+5]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*6+6]-x1;
float y2=buf[k*6+7]-y1;
float z2=buf[k*6+8]-z1;
float nx2=buf[k*6+9]-nx1;
float ny2=buf[k*6+10]-ny1;
float nz2=buf[k*6+11]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*6+12]-x1;
float y2=buf[k*6+13]-y1;
float z2=buf[k*6+14]-z1;
float nx2=buf[k*6+15]-nx1;
float ny2=buf[k*6+16]-ny1;
float nz2=buf[k*6+17]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*6+18]-x1;
float y2=buf[k*6+19]-y1;
float z2=buf[k*6+20]-z1;
float nx2=buf[k*6+21]-nx1;
float ny2=buf[k*6+22]-ny1;
float nz2=buf[k*6+23]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}else{
for (int k=0;k<end_ka;k+=4){
{
float x2=buf[k*6+0]-x1;
float y2=buf[k*6+1]-y1;
float z2=buf[k*6+2]-z1;
float nx2=buf[k*6+3]-nx1;
float ny2=buf[k*6+4]-ny1;
float nz2=buf[k*6+5]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
{
float x2=buf[k*6+6]-x1;
float y2=buf[k*6+7]-y1;
float z2=buf[k*6+8]-z1;
float nx2=buf[k*6+9]-nx1;
float ny2=buf[k*6+10]-ny1;
float nz2=buf[k*6+11]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (d<best){
best=d;
best_i=k+k2+1;
}
}
{
float x2=buf[k*6+12]-x1;
float y2=buf[k*6+13]-y1;
float z2=buf[k*6+14]-z1;
float nx2=buf[k*6+15]-nx1;
float ny2=buf[k*6+16]-ny1;
float nz2=buf[k*6+17]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (d<best){
best=d;
best_i=k+k2+2;
}
}
{
float x2=buf[k*6+18]-x1;
float y2=buf[k*6+19]-y1;
float z2=buf[k*6+20]-z1;
float nx2=buf[k*6+21]-nx1;
float ny2=buf[k*6+22]-ny1;
float nz2=buf[k*6+23]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (d<best){
best=d;
best_i=k+k2+3;
}
}
}
}
for (int k=end_ka;k<end_k;k++){
float x2=buf[k*6+0]-x1;
float y2=buf[k*6+1]-y1;
float z2=buf[k*6+2]-z1;
float nx2=buf[k*6+3]-nx1;
float ny2=buf[k*6+4]-ny1;
float nz2=buf[k*6+5]-nz1;
float d=x2*x2+y2*y2+z2*z2+nx2*nx2+ny2*ny2+nz2*nz2;
if (k==0 || d<best){
best=d;
best_i=k+k2;
}
}
if (k2==0 || result[(i*n+j)]>best){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
float x1=xyz1[(i*n+j)*6+0];
float y1=xyz1[(i*n+j)*6+1];
float z1=xyz1[(i*n+j)*6+2];
float nx1=xyz1[(i*n+j)*6+3];
float ny1=xyz1[(i*n+j)*6+4];
float nz1=xyz1[(i*n+j)*6+5];
int j2=idx1[i*n+j];
float x2=xyz2[(i*m+j2)*6+0];
float y2=xyz2[(i*m+j2)*6+1];
float z2=xyz2[(i*m+j2)*6+2];
float nx2=xyz2[(i*m+j2)*6+3];
float ny2=xyz2[(i*m+j2)*6+4];
float nz2=xyz2[(i*m+j2)*6+5];
float g=grad_dist1[i*n+j]*2;
atomicAdd(&(grad_xyz1[(i*n+j)*6+0]),g*(x1-x2));
atomicAdd(&(grad_xyz1[(i*n+j)*6+1]),g*(y1-y2));
atomicAdd(&(grad_xyz1[(i*n+j)*6+2]),g*(z1-z2));
atomicAdd(&(grad_xyz1[(i*n+j)*6+3]),g*(nx1-nx2));
atomicAdd(&(grad_xyz1[(i*n+j)*6+4]),g*(ny1-ny2));
atomicAdd(&(grad_xyz1[(i*n+j)*6+5]),g*(nz1-nz2));
atomicAdd(&(grad_xyz2[(i*m+j2)*6+0]),-(g*(x1-x2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*6+1]),-(g*(y1-y2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*6+2]),-(g*(z1-z2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*6+3]),-(g*(nx1-nx2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*6+4]),-(g*(ny1-ny2)));
atomicAdd(&(grad_xyz2[(i*m+j2)*6+5]),-(g*(nz1-nz2)));
}
}
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
// cudaMemset(grad_xyz1,0,b*n*3*4);
// cudaMemset(grad_xyz2,0,b*m*3*4);
const auto batch_size = xyz1.size(0);
const auto n = xyz1.size(1); //num_points point cloud A
const auto m = xyz2.size(1); //num_points point cloud B
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
//THError("aborting");
return 0;
}
return 1;
}

View file

@@ -0,0 +1,33 @@
#include <torch/torch.h>
#include <vector>
///TMP
//#include "common.h"
/// NOT TMP
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

View file

@ -0,0 +1,82 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_6D") is not None
if not chamfer_found:
## Cool trick from https://github.com/chrdiller
print("Jitting Chamfer 6D")
cur_path = os.path.dirname(os.path.abspath(__file__))
build_path = cur_path.replace('chamfer6D', 'tmp')
os.makedirs(build_path, exist_ok=True)
from torch.utils.cpp_extension import load
chamfer_6D = load(name="chamfer_6D",
sources=[
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
"/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer6D.cu"]),
], build_directory=build_path)
print("Loaded JIT 6D CUDA chamfer distance")
else:
import chamfer_6D
print("Loaded compiled 6D CUDA chamfer distance")
# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_6DFunction(Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
batchsize, n, dim = xyz1.size()
assert dim==6, "Wrong last dimension for the chamfer distance's input! Check with .size()"
_, m, dim = xyz2.size()
assert dim==6, "Wrong last dimension for the chamfer distance's input! Check with .size()"
device = xyz1.device
dist1 = torch.zeros(batchsize, n)
dist2 = torch.zeros(batchsize, m)
idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
dist1 = dist1.to(device)
dist2 = dist2.to(device)
idx1 = idx1.to(device)
idx2 = idx2.to(device)
torch.cuda.set_device(device)
chamfer_6D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2, idx1, idx2
@staticmethod
def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
graddist1 = graddist1.contiguous()
graddist2 = graddist2.contiguous()
device = graddist1.device
gradxyz1 = torch.zeros(xyz1.size())
gradxyz2 = torch.zeros(xyz2.size())
gradxyz1 = gradxyz1.to(device)
gradxyz2 = gradxyz2.to(device)
chamfer_6D.backward(
xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
)
return gradxyz1, gradxyz2
class chamfer_6DDist(nn.Module):
def __init__(self):
super(chamfer_6DDist, self).__init__()
def forward(self, input1, input2):
input1 = input1.contiguous()
input2 = input2.contiguous()
return chamfer_6DFunction.apply(input1, input2)
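# A minimal usage sketch (not part of the original file): shapes follow the
# asserts in chamfer_6DFunction.forward; the random inputs are hypothetical
# and a CUDA device is assumed, since this extension is GPU-only.
if __name__ == "__main__":
    cham6D = chamfer_6DDist()
    a = torch.rand(4, 1024, 6, device="cuda", requires_grad=True)  # xyz + normals
    b = torch.rand(4, 2048, 6, device="cuda")
    dist1, dist2, idx1, idx2 = cham6D(a, b)
    loss = dist1.mean() + dist2.mean()  # symmetric chamfer objective
    loss.backward()  # gradients flow back to `a` through the custom Function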

@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='chamfer_6D',
ext_modules=[
CUDAExtension('chamfer_6D', [
"/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
"/".join(__file__.split('/')[:-1] + ['chamfer6D.cu']),
]),
],
cmdclass={
'build_ext': BuildExtension
})

@@ -0,0 +1,44 @@
import torch
def pairwise_dist(x, y):
# squared pairwise distances via ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
rx = xx.diag().unsqueeze(0).expand_as(xx)
ry = yy.diag().unsqueeze(0).expand_as(yy)
P = rx.t() + ry - 2 * zz
return P
def NN_loss(x, y, dim=0):
dist = pairwise_dist(x, y)
values, indices = dist.min(dim=dim)
return values.mean()
def batched_pairwise_dist(a, b):
x, y = a.double(), b.double()
bs, num_points_x, points_dim = x.size()
bs, num_points_y, points_dim = y.size()
xx = torch.pow(x, 2).sum(2)
yy = torch.pow(y, 2).sum(2)
zz = torch.bmm(x, y.transpose(2, 1))
rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx
ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy
P = rx.transpose(2, 1) + ry - 2 * zz
return P
def distChamfer(a, b):
"""
:param a: Pointclouds Batch x num_points x dim
:param b: Pointclouds Batch x num_points x dim
:return:
-closest point on b of points from a
-closest point on a of points from b
-idx of closest point on b of points from a
-idx of closest point on a of points from b
Works for pointcloud of any dimension
"""
P = batched_pairwise_dist(a, b)
return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()
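# A small usage sketch (not in the original file); the inputs are hypothetical.
# Note that the returned distances are *squared* Euclidean distances.
if __name__ == "__main__":
    a = torch.rand(2, 100, 3)
    b = torch.rand(2, 120, 3)
    dist_a2b, dist_b2a, idx_a2b, idx_b2a = distChamfer(a, b)
    chamfer = dist_a2b.mean() + dist_b2a.mean()
    print(chamfer.item(), idx_a2b.shape)  # scalar loss, indices of shape (2, 100)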

@@ -0,0 +1,17 @@
import torch
def fscore(dist1, dist2, threshold=0.001):
"""
Calculates the F-score between two point clouds with the corresponding threshold value.
:param dist1: Batch, N-Points
:param dist2: Batch, N-Points
:param threshold: float
:return: fscore, precision, recall
"""
# NB : In this repo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly.
precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
fscore[torch.isnan(fscore)] = 0
return fscore, precision_1, precision_2
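# A minimal usage sketch (not in the original file): dist1/dist2 stand in for
# the squared chamfer distances mentioned above; the threshold is illustrative.
if __name__ == "__main__":
    d1 = torch.rand(4, 100) * 0.002  # squared distances from cloud A to cloud B
    d2 = torch.rand(4, 100) * 0.002  # squared distances from cloud B to cloud A
    f, p1, p2 = fscore(d1, d2, threshold=0.001)
    print(f.shape, p1.shape, p2.shape)  # one value per batch element: (4,)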

@@ -0,0 +1,69 @@
import torch, time
import chamfer2D.dist_chamfer_2D
import chamfer3D.dist_chamfer_3D
import chamfer5D.dist_chamfer_5D
import chamfer_python
cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()
from torch.autograd import Variable
from fscore import fscore
def test_chamfer(distChamfer, dim):
points1 = torch.rand(4, 100, dim).cuda()
points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
dist1, dist2, idx1, idx2= distChamfer(points1, points2)
loss = torch.sum(dist1)
loss.backward()
mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2)
d1 = (dist1 - mydist1) ** 2
d2 = (dist2 - mydist2) ** 2
assert (
torch.mean(d1) + torch.mean(d2) < 0.00000001
), "chamfer cuda and chamfer normal are not giving the same results"
xd1 = idx1 - myidx1
xd2 = idx2 - myidx2
assert (
torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
), "chamfer cuda and chamfer normal are not giving the same results"
print(f"fscore :", fscore(dist1, dist2))
print("Unit test passed")
def timings(distChamfer, dim):
p1 = torch.rand(32, 2000, dim).cuda()
p2 = torch.rand(32, 1000, dim).cuda()
print("Timings : Start CUDA version")
start = time.time()
num_it = 100
for i in range(num_it):
points1 = Variable(p1, requires_grad=True)
points2 = Variable(p2)
mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2)
loss = torch.sum(mydist1)
loss.backward()
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
print("Timings : Start Pythonic version")
start = time.time()
for i in range(num_it):
points1 = Variable(p1, requires_grad=True)
points2 = Variable(p2)
mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2)
loss = torch.sum(mydist1)
loss.backward()
print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.")
dims = [2,3,5]
for i,cham in enumerate([cham2D, cham3D, cham5D]):
print(f"testing Chamfer {dims[i]}D")
test_chamfer(cham, dims[i])
timings(cham, dims[i])

5
third_party/PyTorchEMD/.gitignore vendored Normal file
@@ -0,0 +1,5 @@
__pycache__
build
dist
emd_ext.egg-info
*.so

34
third_party/PyTorchEMD/README.md vendored Normal file
@@ -0,0 +1,34 @@
* adapted from https://github.com/daerduoCarey/PyTorchEMD
---------------------------------
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
## Dependency
The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.
## Usage
First compile using
python setup.py install
Then, copy the lib file out to the main directory,
cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
Then, you can use it by simply
from emd import earth_mover_distance
d = earth_mover_distance(p1, p2, transpose=False) # p1: B x N1 x 3, p2: B x N2 x 3
Check `test_emd_loss.py` for example.
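A minimal end-to-end sketch, assuming a CUDA device and a successfully built
extension (the random tensors are purely illustrative):

```python
import torch
from emd import earth_mover_distance

p1 = torch.rand(8, 1024, 3, device="cuda", requires_grad=True)
p2 = torch.rand(8, 1024, 3, device="cuda")
d = earth_mover_distance(p1, p2, transpose=False)  # (8,): one cost per batch element
d.sum().backward()
```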
## Author
The CUDA code was originally written by Haoqiang Fan. The PyTorch wrapper was written by Kaichun Mo. Jiayuan Gu also provided help.
## License
MIT

0
third_party/PyTorchEMD/__init__.py vendored Executable file

21
third_party/PyTorchEMD/backend.py vendored Executable file
@@ -0,0 +1,21 @@
import os
import time
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
if not os.path.exists(os.path.join(_src_path, 'build_dynamic')):
os.makedirs(os.path.join(_src_path, 'build_dynamic'))
tic = time.time()
emd_cuda_dynamic = load(name='emd_ext',
extra_cflags=['-O3', '-std=c++17'],
## build_directory=os.path.join(_src_path, 'build_dynamic'),
verbose=True,
sources=[
os.path.join(_src_path, f) for f in [
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
]
])
print('load emd_ext time: {:.3f}s'.format(time.time() - tic))
__all__ = ['emd_cuda_dynamic']

29
third_party/PyTorchEMD/cuda/emd.cpp vendored Executable file
@@ -0,0 +1,29 @@
#ifndef _EMD
#define _EMD
#include <vector>
#include <torch/extension.h>
//CUDA declarations
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2);
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
}
#endif

@@ -0,0 +1,398 @@
/**********************************
* Original Author: Haoqiang Fan
* Modified by: Kaichun Mo
*********************************/
#ifndef _EMD_KERNEL
#define _EMD_KERNEL
#include <cmath>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
#include <THC/THC.h>
#define CHECK_INPUT(x)
/********************************
* Forward kernel for approxmatch
*********************************/
template<typename scalar_t>
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
scalar_t multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ scalar_t buf[Block*4];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
match[i*n*m+j]=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x)
remainL[j]=multiL;
for (int j=threadIdx.x;j<m;j+=blockDim.x)
remainR[j]=multiR;
__syncthreads();
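// Explanatory note (not in the original): level = -4^j anneals the softness of
// the matching, from nearly uniform assignments (j=7) towards hard assignments,
// ending with a distance-independent pass at j==-2 where level is forced to 0.
// Each pass redistributes the remaining mass in remainL/remainR via
// exp(level * squared_distance).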
for (int j=7;j>=-2;j--){
scalar_t level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
scalar_t suml=1e-9f;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
scalar_t x2=xyz2[i*m*3+l0*3+l*3+0];
scalar_t y2=xyz2[i*m*3+l0*3+l*3+1];
scalar_t z2=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+0]=x2;
buf[l*4+1]=y2;
buf[l*4+2]=z2;
buf[l*4+3]=remainR[l0+l];
}
__syncthreads();
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*4+0];
scalar_t y2=buf[l*4+1];
scalar_t z2=buf[l*4+2];
scalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
scalar_t w=__expf(d)*buf[l*4+3];
suml+=w;
}
__syncthreads();
}
if (k<n)
ratioL[k]=remainL[k]/suml;
}
__syncthreads();
for (int l0=0;l0<m;l0+=blockDim.x){
int l=l0+threadIdx.x;
scalar_t x2=0,y2=0,z2=0;
if (l<m){
x2=xyz2[i*m*3+l*3+0];
y2=xyz2[i*m*3+l*3+1];
z2=xyz2[i*m*3+l*3+2];
}
scalar_t sumr=0;
for (int k0=0;k0<n;k0+=Block){
int kend=min(n,k0+Block)-k0;
for (int k=threadIdx.x;k<kend;k+=blockDim.x){
buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
buf[k*4+3]=ratioL[k0+k];
}
__syncthreads();
for (int k=0;k<kend;k++){
scalar_t x1=buf[k*4+0];
scalar_t y1=buf[k*4+1];
scalar_t z1=buf[k*4+2];
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
sumr+=w;
}
__syncthreads();
}
if (l<m){
sumr*=remainR[l];
scalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
ratioR[l]=consumption*remainR[l];
remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
}
}
__syncthreads();
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
scalar_t suml=0;
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend;l+=blockDim.x){
buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
buf[l*4+3]=ratioR[l0+l];
}
__syncthreads();
scalar_t rl=ratioL[k];
if (k<n){
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*4+0];
scalar_t y2=buf[l*4+1];
scalar_t z2=buf[l*4+2];
scalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
match[i*n*m+(l0+l)*n+k]+=w;
suml+=w;
}
}
__syncthreads();
}
if (k<n)
remainL[k]=fmaxf(0.0f,remainL[k]-suml);
}
__syncthreads();
}
}
}
//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){
// approxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);
//}
/* ApproxMatch forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
Output:
match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto match = at::zeros({b, m, n}, xyz1.type());
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return match;
}
/********************************
* Forward kernel for matchcost
*********************************/
template<typename scalar_t>
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
__shared__ scalar_t allsum[512];
const int Block=1024;
__shared__ scalar_t buf[Block*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
scalar_t subsum=0;
for (int k0=0;k0<n;k0+=blockDim.x){
int k=k0+threadIdx.x;
scalar_t x1=0,y1=0,z1=0;
if (k<n){
x1=xyz1[i*n*3+k*3+0];
y1=xyz1[i*n*3+k*3+1];
z1=xyz1[i*n*3+k*3+2];
}
for (int l0=0;l0<m;l0+=Block){
int lend=min(m,l0+Block)-l0;
for (int l=threadIdx.x;l<lend*3;l+=blockDim.x)
buf[l]=xyz2[i*m*3+l0*3+l];
__syncthreads();
if (k<n){
for (int l=0;l<lend;l++){
scalar_t x2=buf[l*3+0];
scalar_t y2=buf[l*3+1];
scalar_t z2=buf[l*3+2];
scalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
subsum+=d*match[i*n*m+(l0+l)*n+k];
}
}
__syncthreads();
}
}
allsum[threadIdx.x]=subsum;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
allsum[threadIdx.x]+=allsum[threadIdx.x+j];
}
}
if (threadIdx.x==0)
out[i]=allsum[0];
__syncthreads();
}
}
//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){
// matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
//}
/* MatchCost forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
cost: (B)
*/
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto cost = at::zeros({b}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return cost;
}
/********************************
* matchcostgrad2 kernel
*********************************/
template<typename scalar_t>
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
__shared__ scalar_t sum_grad[256*3];
for (int i=blockIdx.x;i<b;i+=gridDim.x){
int kbeg=m*blockIdx.y/gridDim.y;
int kend=m*(blockIdx.y+1)/gridDim.y;
for (int k=kbeg;k<kend;k++){
scalar_t x2=xyz2[(i*m+k)*3+0];
scalar_t y2=xyz2[(i*m+k)*3+1];
scalar_t z2=xyz2[(i*m+k)*3+2];
scalar_t subsumx=0,subsumy=0,subsumz=0;
for (int j=threadIdx.x;j<n;j+=blockDim.x){
scalar_t x1=x2-xyz1[(i*n+j)*3+0];
scalar_t y1=y2-xyz1[(i*n+j)*3+1];
scalar_t z1=z2-xyz1[(i*n+j)*3+2];
scalar_t d=match[i*n*m+k*n+j]*2;
subsumx+=x1*d;
subsumy+=y1*d;
subsumz+=z1*d;
}
sum_grad[threadIdx.x*3+0]=subsumx;
sum_grad[threadIdx.x*3+1]=subsumy;
sum_grad[threadIdx.x*3+2]=subsumz;
for (int j=1;j<blockDim.x;j<<=1){
__syncthreads();
int j1=threadIdx.x;
int j2=threadIdx.x+j;
if ((j1&j)==0 && j2<blockDim.x){
sum_grad[j1*3+0]+=sum_grad[j2*3+0];
sum_grad[j1*3+1]+=sum_grad[j2*3+1];
sum_grad[j1*3+2]+=sum_grad[j2*3+2];
}
}
if (threadIdx.x==0){
grad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];
grad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];
grad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];
}
__syncthreads();
}
}
}
/********************************
* matchcostgrad1 kernel
*********************************/
template<typename scalar_t>
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
for (int i=blockIdx.x;i<b;i+=gridDim.x){
for (int l=threadIdx.x;l<n;l+=blockDim.x){
scalar_t x1=xyz1[i*n*3+l*3+0];
scalar_t y1=xyz1[i*n*3+l*3+1];
scalar_t z1=xyz1[i*n*3+l*3+2];
scalar_t dx=0,dy=0,dz=0;
for (int k=0;k<m;k++){
scalar_t x2=xyz2[i*m*3+k*3+0];
scalar_t y2=xyz2[i*m*3+k*3+1];
scalar_t z2=xyz2[i*m*3+k*3+2];
scalar_t d=match[i*n*m+k*n+l]*2;
dx+=(x1-x2)*d;
dy+=(y1-y2)*d;
dz+=(z1-z2)*d;
}
grad1[i*n*3+l*3+0]=dx*grad_cost[i];
grad1[i*n*3+l*3+1]=dy*grad_cost[i];
grad1[i*n*3+l*3+2]=dz*grad_cost[i];
}
}
}
//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){
// matchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);
// matchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/* MatchCost backward interface
Input:
grad_cost: (B) # gradients on cost
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
grad1: (B, N1, 3)
grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
}));
THCudaCheck(cudaGetLastError());
return std::vector<at::Tensor>({grad1, grad2});
}
#endif

52
third_party/PyTorchEMD/emd.py vendored Executable file
@@ -0,0 +1,52 @@
import torch
# from backend import emd_cuda_dynamic as emd_cuda # jit compiling
from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only CUDA tensors are supported currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
@custom_bwd
def backward(ctx, grad_cost):
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True):
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (b, 3, n1)
xyz2 (torch.Tensor): (b, 3, n2)
transpose (bool): whether to transpose the inputs, since they may be in
BCN format; the extension only supports BNC format.
Returns:
cost (torch.Tensor): (b)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
# xyz1: B,N,3
N = xyz1.shape[1]
assert(xyz1.shape[-1] == 3), f'expected shape B,N,3; got: {xyz1.shape}'
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(N)
return cost

9
third_party/PyTorchEMD/emd_cuda.py vendored Normal file
@@ -0,0 +1,9 @@
def __bootstrap__():
global __bootstrap__, __loader__, __file__
import sys, pkg_resources, importlib.util
__file__ = pkg_resources.resource_filename(__name__, 'emd_cuda.cpython-38-x86_64-linux-gnu.so')
__loader__ = None; del __bootstrap__, __loader__
spec = importlib.util.spec_from_file_location(__name__,__file__)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
__bootstrap__()

45
third_party/PyTorchEMD/emd_nograd.py vendored Normal file
@@ -0,0 +1,45 @@
import torch
#import emd_cuda
# from evaluation.PyTorchEMD import emd_cuda
from third_party.PyTorchEMD.backend import emd_cuda_dynamic as emd_cuda
class EarthMoverDistanceFunctionNoGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only CUDA tensors are supported currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
# ctx.save_for_backward(xyz1, xyz2, match)
return cost
def earth_mover_distance_nograd(xyz1, xyz2, transpose=True):
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (b, 3, n1)
xyz2 (torch.Tensor): (b, 3, n2)
transpose (bool): whether to transpose the inputs, since they may be in
BCN format; the extension only supports BNC format.
Returns:
cost (torch.Tensor): (b)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
# xyz1: B,N,3
N = xyz1.shape[1]
assert(xyz1.shape[-1] == 3), f'expected shape B,N,3; got: {xyz1.shape}'
#print('xyz1: ', xyz1.shape, xyz2.shape, xyz1.min(), xyz1.max(), xyz2.min(), xyz2.max())
cost = EarthMoverDistanceFunctionNoGrad.apply(xyz1, xyz2) / float(N)
return cost

49
third_party/PyTorchEMD/emd_static.py vendored Executable file
@@ -0,0 +1,49 @@
import torch
import emd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only CUDA tensors are supported currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost):
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
def earth_mover_distance(xyz1, xyz2, transpose=True):
"""Earth Mover Distance (Approx)
Args:
xyz1 (torch.Tensor): (b, 3, n1)
xyz2 (torch.Tensor): (b, 3, n2)
transpose (bool): whether to transpose the inputs, since they may be in
BCN format; the extension only supports BNC format.
Returns:
cost (torch.Tensor): (b)
"""
if xyz1.dim() == 2:
xyz1 = xyz1.unsqueeze(0)
if xyz2.dim() == 2:
xyz2 = xyz2.unsqueeze(0)
if transpose:
xyz1 = xyz1.transpose(1, 2)
xyz2 = xyz2.transpose(1, 2)
# xyz1: B,N,3
N = xyz1.shape[1]
assert(xyz1.shape[-1] == 3), f'expected shape B,N,3; got: {xyz1.shape}'
cost = EarthMoverDistanceFunction.apply(xyz1, xyz2) / float(N)
return cost

27
third_party/PyTorchEMD/setup.py vendored Executable file
@@ -0,0 +1,27 @@
"""Setup extension
Notes:
If extra_compile_args is provided, you need to provide different instances for different extensions.
Refer to https://github.com/pytorch/pytorch/issues/20169
"""
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
ext_modules=[
CUDAExtension(
name='emd_cuda',
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
),
],
cmdclass={
'build_ext': BuildExtension
})

44
third_party/PyTorchEMD/test_emd_loss.py vendored Normal file
@@ -0,0 +1,44 @@
import torch
import numpy as np
import time
from emd import earth_mover_distance
# gt
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist.backward()
print(p1.grad)
print(p2.grad)
# emd
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
d = earth_mover_distance(p1, p2, transpose=False)
print(d)
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)

21
third_party/pvcnn/LICENSE vendored Normal file
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018 Zhijian Liu, Haotian Tang, Yujun Lin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

2
third_party/pvcnn/README.md vendored Normal file
@@ -0,0 +1,2 @@
* All the code under this folder is based on https://github.com/mit-han-lab/pvcnn/tree/master/modules

@@ -0,0 +1,7 @@
from third_party.pvcnn.functional.ball_query import ball_query
from third_party.pvcnn.functional.devoxelization import trilinear_devoxelize
from third_party.pvcnn.functional.grouping import grouping
from third_party.pvcnn.functional.interpolatation import nearest_neighbor_interpolate
from third_party.pvcnn.functional.loss import kl_loss, huber_loss
from third_party.pvcnn.functional.sampling import gather, furthest_point_sample, logits_mask
from third_party.pvcnn.functional.voxelization import avg_voxelize

29
third_party/pvcnn/functional/backend.py vendored Normal file
@@ -0,0 +1,29 @@
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
if not os.path.exists(os.path.join(_src_path, 'build')):
os.makedirs(os.path.join(_src_path, 'build'))
_backend = load(name='_pvcnn_backend',
extra_cflags=['-O3', '-std=c++17'],
verbose=True,
sources=[
os.path.join(_src_path, 'src', f) for f in [
'ball_query/ball_query.cpp',
'ball_query/ball_query.cu',
'grouping/grouping.cpp',
'grouping/grouping.cu',
'interpolate/neighbor_interpolate.cpp',
'interpolate/neighbor_interpolate.cu',
'interpolate/trilinear_devox.cpp',
'interpolate/trilinear_devox.cu',
'sampling/sampling.cpp',
'sampling/sampling.cu',
'voxelization/vox.cpp',
'voxelization/vox.cu',
'bindings.cpp',
]
])
__all__ = ['_backend']

@@ -0,0 +1,20 @@
from torch.autograd import Function
from third_party.pvcnn.functional.backend import _backend
__all__ = ['ball_query']
def ball_query(centers_coords, points_coords, radius, num_neighbors):
"""
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param radius: float, radius of ball query
:param num_neighbors: int, maximum number of neighbors
:return:
neighbor_indices: indices of neighbors, IntTensor[B, M, U]
"""
centers_coords = centers_coords[:,:3].contiguous()
points_coords = points_coords[:,:3].contiguous()
return _backend.ball_query(centers_coords, points_coords, radius,
num_neighbors)
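# A minimal usage sketch (not in the original file); the tensors are
# hypothetical and a CUDA device is assumed by the backend.
if __name__ == "__main__":
    import torch
    centers = torch.rand(2, 3, 64).cuda()    # M = 64 query centers
    points = torch.rand(2, 3, 1024).cuda()   # N = 1024 points
    idx = ball_query(centers, points, radius=0.2, num_neighbors=16)
    print(idx.shape)  # (2, 64, 16): IntTensor of neighbor indices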

@@ -0,0 +1,45 @@
from torch.autograd import Function
from third_party.pvcnn.functional.backend import _backend
__all__ = ['trilinear_devoxelize']
class TrilinearDevoxelization(Function):
@staticmethod
def forward(ctx, features, coords, resolution, is_training=True):
"""
:param ctx:
:param features: voxel features, FloatTensor[B, C, R, R, R]
:param coords: the coordinates of points, FloatTensor[B, 3, N]
:param resolution: int, the voxel resolution
:param is_training: bool, training mode
:return:
FloatTensor[B, C, N]
"""
B, C = features.shape[:2]
features = features.contiguous().view(B, C, -1)
coords = coords[:,:3].contiguous()
outs, inds, wgts = _backend.trilinear_devoxelize_forward(
resolution, is_training, coords, features)
if is_training:
ctx.save_for_backward(inds, wgts)
ctx.r = resolution
return outs
@staticmethod
def backward(ctx, grad_output):
"""
:param ctx:
:param grad_output: gradient of outputs, FloatTensor[B, C, N]
:return:
gradient of inputs, FloatTensor[B, C, R, R, R]
"""
inds, wgts = ctx.saved_tensors
grad_inputs = _backend.trilinear_devoxelize_backward(
grad_output.contiguous(), inds, wgts, ctx.r)
return grad_inputs.view(grad_output.size(0), grad_output.size(1),
ctx.r, ctx.r, ctx.r), None, None, None
trilinear_devoxelize = TrilinearDevoxelization.apply
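# A minimal usage sketch (not in the original file): `coords` are assumed to be
# voxel-space coordinates in [0, resolution), matching how the kernels index the
# R x R x R feature grid; the tensors are hypothetical and a GPU is assumed.
if __name__ == "__main__":
    import torch
    R = 8
    feats = torch.rand(2, 16, R, R, R).cuda()         # voxel features
    coords = torch.rand(2, 3, 1024).cuda() * (R - 1)  # per-point voxel coords
    out = trilinear_devoxelize(feats, coords, R, True)
    print(out.shape)  # (2, 16, 1024): one interpolated feature per point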

@@ -0,0 +1,33 @@
from torch.autograd import Function
# from modules.functional.backend import _backend
from third_party.pvcnn.functional.backend import _backend
__all__ = ['grouping']
class Grouping(Function):
@staticmethod
def forward(ctx, features, indices):
"""
:param ctx:
:param features: features of points, FloatTensor[B, C, N]
:param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors
:return:
grouped_features: grouped features, FloatTensor[B, C, M, U]
"""
features = features.contiguous()
indices = indices.contiguous()
ctx.save_for_backward(indices)
ctx.num_points = features.size(-1)
return _backend.grouping_forward(features, indices)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_features = _backend.grouping_backward(grad_output.contiguous(),
indices, ctx.num_points)
return grad_features, None
grouping = Grouping.apply
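# A minimal usage sketch (not in the original file); the random indices stand in
# for real neighbor indices (e.g. from ball_query) and a GPU is assumed.
if __name__ == "__main__":
    import torch
    feats = torch.rand(2, 32, 1024).cuda()  # B, C, N
    idx = torch.randint(0, 1024, (2, 64, 16), dtype=torch.int32).cuda()  # B, M, U
    grouped = grouping(feats, idx)
    print(grouped.shape)  # (2, 32, 64, 16): features gathered per neighbor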

@@ -0,0 +1,54 @@
from torch.autograd import Function
# from modules.functional.backend import _backend
from third_party.pvcnn.functional.backend import _backend
import torch
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
__all__ = ['nearest_neighbor_interpolate']
class NeighborInterpolation(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, points_coords, centers_coords, centers_features):
"""
:param ctx:
:param points_coords: coordinates of points, FloatTensor[B, 3, N]
:param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
:param centers_features: features of centers, FloatTensor[B, C, M]
:return:
points_features: features of points, FloatTensor[B, C, N]
"""
centers_coords = centers_coords[:,:3].contiguous()
points_coords = points_coords[:,:3].contiguous()
centers_features = centers_features.contiguous()
points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
points_coords, centers_coords, centers_features)
ctx.save_for_backward(indices, weights)
ctx.num_centers = centers_coords.size(-1)
return points_features
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
indices, weights = ctx.saved_tensors
grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
grad_output.contiguous(), indices, weights, ctx.num_centers)
return None, None, grad_centers_features
nearest_neighbor_interpolate = NeighborInterpolation.apply
#def nearest_neighbor_interpolate(points_coords, centers_coords, centers_features):
# # points_coords: (B,6, 64)
# # centers_coords: (B,6, 16)
# # centers_features: (B,128,16)
# # interpolated_features: (B,128,64)
# B = points_coords.shape[0]
# D = centers_features.shape[1]
# N = points_coords.shape[2]
# output = torch.zeros(B,D,N).to(points_coords.shape)
# for b in range(B):
# for n in range(N):
# points_coords_cur = points_coords
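# A minimal usage sketch (not in the original file): only the first three
# coordinate channels are used, per the slicing above; the data is hypothetical
# and a GPU is assumed by the backend.
if __name__ == "__main__":
    points = torch.rand(2, 3, 64).cuda()          # B, 3, N query points
    centers = torch.rand(2, 3, 16).cuda()         # B, 3, M centers
    center_feats = torch.rand(2, 128, 16).cuda()  # B, C, M center features
    point_feats = nearest_neighbor_interpolate(points, centers, center_feats)
    print(point_feats.shape)  # (2, 128, 64)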

18
third_party/pvcnn/functional/loss.py vendored Normal file
@@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F
__all__ = ['kl_loss', 'huber_loss']
def kl_loss(x, y):
x = F.softmax(x.detach(), dim=1)
y = F.log_softmax(y, dim=1)
return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1))
def huber_loss(error, delta):
abs_error = torch.abs(error)
quadratic = torch.min(abs_error,
torch.full_like(abs_error, fill_value=delta))
losses = 0.5 * (quadratic**2) + delta * (abs_error - quadratic)
return torch.mean(losses)
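# A minimal usage sketch (not in the original file), with hypothetical inputs;
# both losses run on CPU or GPU.
if __name__ == "__main__":
    logits_a = torch.randn(8, 10)
    logits_b = torch.randn(8, 10)
    print(kl_loss(logits_a, logits_b))  # KL(softmax(a) || softmax(b))
    err = torch.randn(8, 10)
    print(huber_loss(err, delta=1.0))   # quadratic near 0, linear beyond delta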

100
third_party/pvcnn/functional/sampling.py vendored Normal file
@@ -0,0 +1,100 @@
import numpy as np
import torch
from torch.autograd import Function
# from modules.functional.backend import _backend
from third_party.pvcnn.functional.backend import _backend
__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
class Gather(Function):
@staticmethod
def forward(ctx, features, indices):
"""
Gather
:param ctx:
:param features: features of points, FloatTensor[B, C, N]
:param indices: centers' indices in points, IntTensor[b, m]
:return:
centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
"""
features = features.contiguous()
indices = indices.int().contiguous()
ctx.save_for_backward(indices)
ctx.num_points = features.size(-1)
return _backend.gather_features_forward(features, indices)
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_features = _backend.gather_features_backward(
grad_output.contiguous(), indices, ctx.num_points)
return grad_features, None
gather = Gather.apply
def furthest_point_sample(coords, num_samples, normals=None):
"""
Uses iterative furthest point sampling to select a set of num_samples points that have the largest
minimum distance to the already-sampled point set
:param coords: coordinates of points, FloatTensor[B, 3, N]
:param num_samples: int, M
:param normals: optional per-point normals, FloatTensor[B, 3, N]
:return:
centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
(and center_normals, FloatTensor[B, 3, M], when normals are given)
"""
assert(len(coords.shape) == 3 and coords.shape[1] == 3), f'expected input of shape B,3,N; got: {coords.shape}'
coords = coords.contiguous()
indices = _backend.furthest_point_sampling(coords, num_samples)
centers_coords = gather(coords, indices)
if normals is not None:
center_normals = gather(normals, indices)
return centers_coords if normals is None else (centers_coords, center_normals)
def logits_mask(coords, logits, num_points_per_object):
"""
Use logits to sample points
:param coords: coords of points, FloatTensor[B, 3, N]
:param logits: binary classification logits, FloatTensor[B, 2, N]
:param num_points_per_object: M, #points per object after masking, int
:return:
selected_coords: FloatTensor[B, 3, M]
masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
mask: mask to select points, BoolTensor[B, N]
"""
batch_size, _, num_points = coords.shape
mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(
num_candidates, torch.ones_like(num_candidates)).float() # [B, C]
selected_indices = torch.zeros((batch_size, num_points_per_object),
device=coords.device,
dtype=torch.int32)
for i in range(batch_size):
current_mask = mask[i] # [N]
current_candidates = current_mask.nonzero().view(-1)
current_num_candidates = current_candidates.numel()
if current_num_candidates >= num_points_per_object:
choices = np.random.choice(current_num_candidates,
num_points_per_object,
replace=False)
selected_indices[i] = current_candidates[choices]
elif current_num_candidates > 0:
choices = np.concatenate([
np.arange(current_num_candidates).repeat(
num_points_per_object // current_num_candidates),
np.random.choice(current_num_candidates,
num_points_per_object %
current_num_candidates,
replace=False)
])
np.random.shuffle(choices)
selected_indices[i] = current_candidates[choices]
selected_coords = gather(
masked_coords - masked_coords_mean.view(batch_size, -1, 1),
selected_indices)
return selected_coords, masked_coords_mean, mask
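# A minimal usage sketch (not in the original file); the coordinates are
# hypothetical and a GPU is assumed by the backend.
if __name__ == "__main__":
    coords = torch.rand(2, 3, 4096).cuda()        # B, 3, N
    centers = furthest_point_sample(coords, 512)  # B, 3, 512
    print(centers.shape)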

@@ -0,0 +1,30 @@
#include "ball_query.hpp"
#include "ball_query.cuh"
#include "../utils.hpp"
at::Tensor ball_query_forward(at::Tensor centers_coords,
at::Tensor points_coords, const float radius,
const int num_neighbors) {
CHECK_CUDA(centers_coords);
CHECK_CUDA(points_coords);
CHECK_CONTIGUOUS(centers_coords);
CHECK_CONTIGUOUS(points_coords);
CHECK_IS_FLOAT(centers_coords);
CHECK_IS_FLOAT(points_coords);
int b = centers_coords.size(0);
int m = centers_coords.size(2);
int n = points_coords.size(2);
at::Tensor neighbors_indices = torch::zeros(
{b, m, num_neighbors},
at::device(centers_coords.device()).dtype(at::ScalarType::Int));
ball_query(b, n, m, radius * radius, num_neighbors,
centers_coords.data_ptr<float>(),
points_coords.data_ptr<float>(),
neighbors_indices.data_ptr<int>());
return neighbors_indices;
}

@@ -0,0 +1,59 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: ball query
Args:
b : batch size
n : number of points in point clouds
m : number of query centers
r2 : ball query radius ** 2
u : maximum number of neighbors
centers_coords: coordinates of centers, FloatTensor[b, 3, m]
points_coords : coordinates of points, FloatTensor[b, 3, n]
neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
*/
__global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
const float *__restrict__ centers_coords,
const float *__restrict__ points_coords,
int *__restrict__ neighbors_indices) {
int batch_index = blockIdx.x;
int index = threadIdx.x;
int stride = blockDim.x;
points_coords += batch_index * n * 3;
centers_coords += batch_index * m * 3;
neighbors_indices += batch_index * m * u;
for (int j = index; j < m; j += stride) {
float center_x = centers_coords[j];
float center_y = centers_coords[j + m];
float center_z = centers_coords[j + m + m];
for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
float dx = center_x - points_coords[k];
float dy = center_y - points_coords[k + n];
float dz = center_z - points_coords[k + n + n];
float d2 = dx * dx + dy * dy + dz * dz;
if (d2 < r2) {
if (cnt == 0) {
for (int v = 0; v < u; ++v) {
neighbors_indices[j * u + v] = k;
}
}
neighbors_indices[j * u + cnt] = k;
++cnt;
}
}
}
}
void ball_query(int b, int n, int m, float r2, int u,
const float *centers_coords, const float *points_coords,
int *neighbors_indices) {
ball_query_kernel<<<b, optimal_num_threads(m), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
CUDA_CHECK_ERRORS();
}

@@ -0,0 +1,8 @@
#ifndef _BALL_QUERY_CUH
#define _BALL_QUERY_CUH
void ball_query(int b, int n, int m, float r2, int u,
const float *centers_coords, const float *points_coords,
int *neighbors_indices);
#endif

@@ -0,0 +1,10 @@
#ifndef _BALL_QUERY_HPP
#define _BALL_QUERY_HPP
#include <torch/extension.h>
at::Tensor ball_query_forward(at::Tensor centers_coords,
at::Tensor points_coords, const float radius,
const int num_neighbors);
#endif

@@ -0,0 +1,37 @@
#include <pybind11/pybind11.h>
#include "ball_query/ball_query.hpp"
#include "grouping/grouping.hpp"
#include "interpolate/neighbor_interpolate.hpp"
#include "interpolate/trilinear_devox.hpp"
#include "sampling/sampling.hpp"
#include "voxelization/vox.hpp"
PYBIND11_MODULE(_pvcnn_backend, m) {
m.def("gather_features_forward", &gather_features_forward,
"Gather Centers' Features forward (CUDA)");
m.def("gather_features_backward", &gather_features_backward,
"Gather Centers' Features backward (CUDA)");
m.def("furthest_point_sampling", &furthest_point_sampling_forward,
"Furthest Point Sampling (CUDA)");
m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
m.def("grouping_forward", &grouping_forward,
"Grouping Features forward (CUDA)");
m.def("grouping_backward", &grouping_backward,
"Grouping Features backward (CUDA)");
m.def("three_nearest_neighbors_interpolate_forward",
&three_nearest_neighbors_interpolate_forward,
"3 Nearest Neighbors Interpolate forward (CUDA)");
m.def("three_nearest_neighbors_interpolate_backward",
&three_nearest_neighbors_interpolate_backward,
"3 Nearest Neighbors Interpolate backward (CUDA)");
m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
"Trilinear Devoxelization forward (CUDA)");
m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
"Trilinear Devoxelization backward (CUDA)");
m.def("avg_voxelize_forward", &avg_voxelize_forward,
"Voxelization forward with average pooling (CUDA)");
m.def("avg_voxelize_backward", &avg_voxelize_backward,
"Voxelization backward (CUDA)");
}

@@ -0,0 +1,39 @@
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#define MAXIMUM_THREADS 512
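// Explanatory note (not in the original): round the work size down to a power
// of two and clamp it to [1, MAXIMUM_THREADS], so each block launches a thread
// count matched to its per-batch workload.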
inline int optimal_num_threads(int work_size) {
const int pow_2 = std::log2(static_cast<double>(work_size));
return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
}
inline dim3 optimal_block_config(int x, int y) {
const int x_threads = optimal_num_threads(x);
const int y_threads =
max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
dim3 block_config(x_threads, y_threads, 1);
return block_config;
}
#define CUDA_CHECK_ERRORS() \
{ \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
__FILE__); \
exit(-1); \
} \
}
#endif

@@ -0,0 +1,44 @@
#include "grouping.hpp"
#include "grouping.cuh"
#include "../utils.hpp"
at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
CHECK_CUDA(features);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(features);
CHECK_IS_INT(indices);
int b = features.size(0);
int c = features.size(1);
int n = features.size(2);
int m = indices.size(1);
int u = indices.size(2);
at::Tensor output = torch::zeros(
{b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
output.data_ptr<float>());
return output;
}
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
const int n) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
int b = grad_y.size(0);
int c = grad_y.size(1);
int m = indices.size(1);
int u = indices.size(2);
at::Tensor grad_x = torch::zeros(
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
indices.data_ptr<int>(), grad_x.data_ptr<float>());
return grad_x;
}

@@ -0,0 +1,85 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: grouping features of neighbors (forward)
Args:
b : batch size
c : #channels of features
n : number of points in point clouds
m : number of query centers
u : maximum number of neighbors
features: points' features, FloatTensor[b, c, n]
indices : neighbor indices in points, IntTensor[b, m, u]
out : gathered features, FloatTensor[b, c, m, u]
*/
__global__ void grouping_kernel(int b, int c, int n, int m, int u,
const float *__restrict__ features,
const int *__restrict__ indices,
float *__restrict__ out) {
int batch_index = blockIdx.x;
features += batch_index * n * c;
indices += batch_index * m * u;
out += batch_index * m * u * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * m; i += stride) {
const int l = i / m;
const int j = i % m;
for (int k = 0; k < u; ++k) {
out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
}
}
}
void grouping(int b, int c, int n, int m, int u, const float *features,
const int *indices, float *out) {
grouping_kernel<<<b, optimal_block_config(m, c), 0,
at::cuda::getCurrentCUDAStream()>>>(b, c, n, m, u, features,
indices, out);
CUDA_CHECK_ERRORS();
}
/*
Function: grouping features of neighbors (backward)
Args:
b : batch size
c : #channels of features
n : number of points in point clouds
m : number of query centers
u : maximum number of neighbors
grad_y : grad of gathered features, FloatTensor[b, c, m, u]
indices : neighbor indices in points, IntTensor[b, m, u]
grad_x: grad of points' features, FloatTensor[b, c, n]
*/
__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
const float *__restrict__ grad_y,
const int *__restrict__ indices,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
grad_y += batch_index * m * u * c;
indices += batch_index * m * u;
grad_x += batch_index * n * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * m; i += stride) {
const int l = i / m;
const int j = i % m;
for (int k = 0; k < u; ++k) {
atomicAdd(grad_x + l * n + indices[j * u + k],
grad_y[(l * m + j) * u + k]);
}
}
}
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
const int *indices, float *grad_x) {
grouping_grad_kernel<<<b, optimal_block_config(m, c), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, u, grad_y, indices, grad_x);
CUDA_CHECK_ERRORS();
}

@@ -0,0 +1,9 @@
#ifndef _GROUPING_CUH
#define _GROUPING_CUH
void grouping(int b, int c, int n, int m, int u, const float *features,
const int *indices, float *out);
void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
const int *indices, float *grad_x);
#endif

@@ -0,0 +1,10 @@
#ifndef _GROUPING_HPP
#define _GROUPING_HPP
#include <torch/extension.h>
at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
const int n);
#endif

@@ -0,0 +1,65 @@
#include "neighbor_interpolate.hpp"
#include "neighbor_interpolate.cuh"
#include "../utils.hpp"
std::vector<at::Tensor>
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
at::Tensor centers_coords,
at::Tensor centers_features) {
CHECK_CUDA(points_coords);
CHECK_CUDA(centers_coords);
CHECK_CUDA(centers_features);
CHECK_CONTIGUOUS(points_coords);
CHECK_CONTIGUOUS(centers_coords);
CHECK_CONTIGUOUS(centers_features);
CHECK_IS_FLOAT(points_coords);
CHECK_IS_FLOAT(centers_coords);
CHECK_IS_FLOAT(centers_features);
int b = centers_features.size(0);
int c = centers_features.size(1);
int m = centers_features.size(2);
int n = points_coords.size(2);
at::Tensor indices = torch::zeros(
{b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
at::Tensor weights = torch::zeros(
{b, 3, n},
at::device(points_coords.device()).dtype(at::ScalarType::Float));
at::Tensor output = torch::zeros(
{b, c, n},
at::device(centers_features.device()).dtype(at::ScalarType::Float));
three_nearest_neighbors_interpolate(
b, c, m, n, points_coords.data_ptr<float>(),
centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
indices.data_ptr<int>(), weights.data_ptr<float>(),
output.data_ptr<float>());
return {output, indices, weights};
}
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
at::Tensor indices,
at::Tensor weights,
const int m) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CUDA(weights);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_CONTIGUOUS(weights);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
CHECK_IS_FLOAT(weights);
int b = grad_y.size(0);
int c = grad_y.size(1);
int n = grad_y.size(2);
at::Tensor grad_x = torch::zeros(
{b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
three_nearest_neighbors_interpolate_grad(
b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
weights.data_ptr<float>(), grad_x.data_ptr<float>());
return grad_x;
}

@@ -0,0 +1,181 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: three nearest neighbors
Args:
b : batch size
n : number of points in point clouds
m : number of query centers
points_coords : coordinates of points, FloatTensor[b, 3, n]
centers_coords: coordinates of centers, FloatTensor[b, 3, m]
weights : weights of nearest 3 centers to the point,
FloatTensor[b, 3, n]
indices : indices of nearest 3 centers to the point,
IntTensor[b, 3, n]
*/
__global__ void three_nearest_neighbors_kernel(
int b, int n, int m, const float *__restrict__ points_coords,
const float *__restrict__ centers_coords, float *__restrict__ weights,
int *__restrict__ indices) {
int batch_index = blockIdx.x;
int index = threadIdx.x;
int stride = blockDim.x;
points_coords += batch_index * 3 * n;
weights += batch_index * 3 * n;
indices += batch_index * 3 * n;
centers_coords += batch_index * 3 * m;
for (int j = index; j < n; j += stride) {
float ux = points_coords[j];
float uy = points_coords[j + n];
float uz = points_coords[j + n + n];
double best0 = 1e40, best1 = 1e40, best2 = 1e40;
int besti0 = 0, besti1 = 0, besti2 = 0;
for (int k = 0; k < m; ++k) {
float x = centers_coords[k];
float y = centers_coords[k + m];
float z = centers_coords[k + m + m];
float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
if (d < best2) {
best2 = d;
besti2 = k;
if (d < best1) {
best2 = best1;
besti2 = besti1;
best1 = d;
besti1 = k;
if (d < best0) {
best1 = best0;
besti1 = besti0;
best0 = d;
besti0 = k;
}
}
}
}
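// Explanatory note (not in the original): the weights below implement
// inverse-distance interpolation, w_i = (1/d_i) / (1/d_0 + 1/d_1 + 1/d_2),
// rewritten division-free as w_0 = d_1*d_2 / (d_0*d_1 + d_0*d_2 + d_1*d_2);
// the distances are first clamped to [1e-10, 1e10] to keep the products finite.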
best0 = max(min(1e10f, best0), 1e-10f);
best1 = max(min(1e10f, best1), 1e-10f);
best2 = max(min(1e10f, best2), 1e-10f);
float d0d1 = best0 * best1;
float d0d2 = best0 * best2;
float d1d2 = best1 * best2;
float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
weights[j] = d1d2 * d0d1d2;
indices[j] = besti0;
weights[j + n] = d0d2 * d0d1d2;
indices[j + n] = besti1;
weights[j + n + n] = d0d1 * d0d1d2;
indices[j + n + n] = besti2;
}
}
/*
Function: interpolate three nearest neighbors (forward)
Args:
b : batch size
c : #channels of features
m : number of query centers
n : number of points in point clouds
centers_features: features of centers, FloatTensor[b, c, m]
indices : indices of nearest 3 centers to the point,
IntTensor[b, 3, n]
weights : weights for interpolation, FloatTensor[b, 3, n]
out : features of points, FloatTensor[b, c, n]
*/
__global__ void three_nearest_neighbors_interpolate_kernel(
int b, int c, int m, int n, const float *__restrict__ centers_features,
const int *__restrict__ indices, const float *__restrict__ weights,
float *__restrict__ out) {
int batch_index = blockIdx.x;
centers_features += batch_index * m * c;
indices += batch_index * n * 3;
weights += batch_index * n * 3;
out += batch_index * n * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * n; i += stride) {
const int l = i / n;
const int j = i % n;
float w1 = weights[j];
float w2 = weights[j + n];
float w3 = weights[j + n + n];
int i1 = indices[j];
int i2 = indices[j + n];
int i3 = indices[j + n + n];
out[i] = centers_features[l * m + i1] * w1 +
centers_features[l * m + i2] * w2 +
centers_features[l * m + i3] * w3;
}
}
void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
const float *points_coords,
const float *centers_coords,
const float *centers_features,
int *indices, float *weights,
float *out) {
three_nearest_neighbors_kernel<<<b, optimal_num_threads(n), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, n, m, points_coords, centers_coords, weights, indices);
three_nearest_neighbors_interpolate_kernel<<<
b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
b, c, m, n, centers_features, indices, weights, out);
CUDA_CHECK_ERRORS();
}
/*
Function: interpolate three nearest neighbors (backward)
Args:
b : batch size
c : #channels of features
m : number of query centers
n : number of points in point clouds
grad_y : grad of features of points, FloatTensor[b, c, n]
indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
weights : weights for interpolation, FloatTensor[b, 3, n]
grad_x : grad of features of centers, FloatTensor[b, c, m]
*/
__global__ void three_nearest_neighbors_interpolate_grad_kernel(
int b, int c, int n, int m, const float *__restrict__ grad_y,
const int *__restrict__ indices, const float *__restrict__ weights,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
grad_y += batch_index * n * c;
indices += batch_index * n * 3;
weights += batch_index * n * 3;
grad_x += batch_index * m * c;
const int index = threadIdx.y * blockDim.x + threadIdx.x;
const int stride = blockDim.y * blockDim.x;
for (int i = index; i < c * n; i += stride) {
const int l = i / n;
const int j = i % n;
float w1 = weights[j];
float w2 = weights[j + n];
float w3 = weights[j + n + n];
int i1 = indices[j];
int i2 = indices[j + n];
int i3 = indices[j + n + n];
atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
}
}
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
const float *grad_y,
const int *indices,
const float *weights,
float *grad_x) {
three_nearest_neighbors_interpolate_grad_kernel<<<
b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, grad_y, indices, weights, grad_x);
CUDA_CHECK_ERRORS();
}

@@ -0,0 +1,16 @@
#ifndef _NEIGHBOR_INTERPOLATE_CUH
#define _NEIGHBOR_INTERPOLATE_CUH
void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
const float *points_coords,
const float *centers_coords,
const float *centers_features,
int *indices, float *weights,
float *out);
void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
const float *grad_y,
const int *indices,
const float *weights,
float *grad_x);
#endif

@@ -0,0 +1,16 @@
#ifndef _NEIGHBOR_INTERPOLATE_HPP
#define _NEIGHBOR_INTERPOLATE_HPP
#include <torch/extension.h>
#include <vector>
std::vector<at::Tensor>
three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
at::Tensor centers_coords,
at::Tensor centers_features);
at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
at::Tensor indices,
at::Tensor weights,
const int m);
#endif

@@ -0,0 +1,91 @@
#include "trilinear_devox.hpp"
#include "trilinear_devox.cuh"
#include "../utils.hpp"
/*
Function: trilinear devoxelization (forward)
Args:
r : voxel resolution
training : whether in training mode
coords : the coordinates of points, FloatTensor[b, 3, n]
features : features, FloatTensor[b, c, s], s = r ** 3
Return:
outs : outputs, FloatTensor[b, c, n]
inds : the voxel coordinates of point cube, IntTensor[b, 8, n]
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
*/
std::vector<at::Tensor>
trilinear_devoxelize_forward(const int r, const bool is_training,
const at::Tensor coords,
const at::Tensor features) {
CHECK_CUDA(features);
CHECK_CUDA(coords);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(coords);
CHECK_IS_FLOAT(features);
CHECK_IS_FLOAT(coords);
int b = features.size(0);
int c = features.size(1);
int n = coords.size(2);
int r2 = r * r;
int r3 = r2 * r;
at::Tensor outs = torch::zeros(
{b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float));
if (is_training) {
at::Tensor inds = torch::zeros(
{b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Int));
at::Tensor wgts = torch::zeros(
{b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float));
trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr<float>(),
features.data_ptr<float>(), inds.data_ptr<int>(),
wgts.data_ptr<float>(), outs.data_ptr<float>());
return {outs, inds, wgts};
} else {
at::Tensor inds = torch::zeros(
{1}, at::device(features.device()).dtype(at::ScalarType::Int));
at::Tensor wgts = torch::zeros(
{1}, at::device(features.device()).dtype(at::ScalarType::Float));
trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr<float>(),
features.data_ptr<float>(), inds.data_ptr<int>(),
wgts.data_ptr<float>(), outs.data_ptr<float>());
return {outs, inds, wgts};
}
}
/*
Function: trilinear devoxelization (backward)
Args:
grad_y : grad outputs, FloatTensor[b, c, n]
indices : the voxel coordinates of point cube, IntTensor[b, 8, n]
weights : weight for trilinear interpolation, FloatTensor[b, 8, n]
r : voxel resolution
Return:
grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3
*/
at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor weights,
const int r) {
CHECK_CUDA(grad_y);
CHECK_CUDA(weights);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(weights);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_FLOAT(weights);
CHECK_IS_INT(indices);
int b = grad_y.size(0);
int c = grad_y.size(1);
int n = grad_y.size(2);
int r3 = r * r * r;
at::Tensor grad_x = torch::zeros(
{b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr<int>(),
weights.data_ptr<float>(), grad_y.data_ptr<float>(),
grad_x.data_ptr<float>());
return grad_x;
}
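A small usage sketch of the wrapper above (the `_backend` module name and import path are assumptions; in eval mode the returned `inds`/`wgts` are 1-element dummies and only `outs` is meaningful):

```
import torch
from third_party.pvcnn.functional.backend import _backend  # assumed path

r = 8
coords = torch.rand(2, 3, 1024, device='cuda') * (r - 1)  # [B, 3, N], in [0, r-1]
feats = torch.rand(2, 16, r ** 3, device='cuda')          # [B, C, r**3]
outs, inds, wgts = _backend.trilinear_devoxelize_forward(r, True, coords, feats)
outs, _, _ = _backend.trilinear_devoxelize_forward(r, False, coords, feats)
```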

@@ -0,0 +1,178 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: trilinear devoxelization (forward)
Args:
b : batch size
c : #channels
n : number of points
r : voxel resolution
r2 : r ** 2
r3 : r ** 3
coords : the coordinates of points, FloatTensor[b, 3, n]
feat : features, FloatTensor[b, c, r3]
inds : the voxel indices of point cube, IntTensor[b, 8, n]
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
outs : outputs, FloatTensor[b, c, n]
*/
__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2,
int r3, bool is_training,
const float *__restrict__ coords,
const float *__restrict__ feat,
int *__restrict__ inds,
float *__restrict__ wgts,
float *__restrict__ outs) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
coords += batch_index * n * 3;
inds += batch_index * n * 8;
wgts += batch_index * n * 8;
feat += batch_index * c * r3;
outs += batch_index * c * n;
for (int i = index; i < n; i += stride) {
float x = coords[i];
float y = coords[i + n];
float z = coords[i + n + n];
float x_lo_f = floorf(x);
float y_lo_f = floorf(y);
float z_lo_f = floorf(z);
float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f)
float y_d_1 = y - y_lo_f;
float z_d_1 = z - z_lo_f;
float x_d_0 = 1.0f - x_d_1;
float y_d_0 = 1.0f - y_d_1;
float z_d_0 = 1.0f - z_d_1;
float wgt000 = x_d_0 * y_d_0 * z_d_0;
float wgt001 = x_d_0 * y_d_0 * z_d_1;
float wgt010 = x_d_0 * y_d_1 * z_d_0;
float wgt011 = x_d_0 * y_d_1 * z_d_1;
float wgt100 = x_d_1 * y_d_0 * z_d_0;
float wgt101 = x_d_1 * y_d_0 * z_d_1;
float wgt110 = x_d_1 * y_d_1 * z_d_0;
float wgt111 = x_d_1 * y_d_1 * z_d_1;
int x_lo = static_cast<int>(x_lo_f);
int y_lo = static_cast<int>(y_lo_f);
int z_lo = static_cast<int>(z_lo_f);
int x_hi = (x_d_1 > 0) ? -1 : 0;
int y_hi = (y_d_1 > 0) ? -1 : 0;
int z_hi = (z_d_1 > 0) ? 1 : 0;
int idx000 = x_lo * r2 + y_lo * r + z_lo;
int idx001 = idx000 + z_hi; // x_lo * r2 + y_lo * r + z_hi;
int idx010 = idx000 + (y_hi & r); // x_lo * r2 + y_hi * r + z_lo;
int idx011 = idx010 + z_hi; // x_lo * r2 + y_hi * r + z_hi;
int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo;
int idx101 = idx100 + z_hi; // x_hi * r2 + y_lo * r + z_hi;
int idx110 = idx100 + (y_hi & r); // x_hi * r2 + y_hi * r + z_lo;
int idx111 = idx110 + z_hi; // x_hi * r2 + y_hi * r + z_hi;
if (is_training) {
wgts[i] = wgt000;
wgts[i + n] = wgt001;
wgts[i + n * 2] = wgt010;
wgts[i + n * 3] = wgt011;
wgts[i + n * 4] = wgt100;
wgts[i + n * 5] = wgt101;
wgts[i + n * 6] = wgt110;
wgts[i + n * 7] = wgt111;
inds[i] = idx000;
inds[i + n] = idx001;
inds[i + n * 2] = idx010;
inds[i + n * 3] = idx011;
inds[i + n * 4] = idx100;
inds[i + n * 5] = idx101;
inds[i + n * 6] = idx110;
inds[i + n * 7] = idx111;
}
for (int j = 0; j < c; j++) {
int jr3 = j * r3;
outs[j * n + i] =
wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] +
wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] +
wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] +
wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111];
}
}
}
/*
Function: trilinear devoxelization (backward)
Args:
b : batch size
c : #channels
n : number of points
r3 : voxel cube size = voxel resolution ** 3
inds : the voxel indices of point cube, IntTensor[b, 8, n]
wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
grad_y : grad outputs, FloatTensor[b, c, n]
grad_x : grad inputs, FloatTensor[b, c, r3]
*/
__global__ void trilinear_devoxelize_grad_kernel(
int b, int c, int n, int r3, const int *__restrict__ inds,
const float *__restrict__ wgts, const float *__restrict__ grad_y,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
inds += batch_index * n * 8;
wgts += batch_index * n * 8;
grad_x += batch_index * c * r3;
grad_y += batch_index * c * n;
for (int i = index; i < n; i += stride) {
int idx000 = inds[i];
int idx001 = inds[i + n];
int idx010 = inds[i + n * 2];
int idx011 = inds[i + n * 3];
int idx100 = inds[i + n * 4];
int idx101 = inds[i + n * 5];
int idx110 = inds[i + n * 6];
int idx111 = inds[i + n * 7];
float wgt000 = wgts[i];
float wgt001 = wgts[i + n];
float wgt010 = wgts[i + n * 2];
float wgt011 = wgts[i + n * 3];
float wgt100 = wgts[i + n * 4];
float wgt101 = wgts[i + n * 5];
float wgt110 = wgts[i + n * 6];
float wgt111 = wgts[i + n * 7];
for (int j = 0; j < c; j++) {
int jr3 = j * r3;
float g = grad_y[j * n + i];
atomicAdd(grad_x + jr3 + idx000, wgt000 * g);
atomicAdd(grad_x + jr3 + idx001, wgt001 * g);
atomicAdd(grad_x + jr3 + idx010, wgt010 * g);
atomicAdd(grad_x + jr3 + idx011, wgt011 * g);
atomicAdd(grad_x + jr3 + idx100, wgt100 * g);
atomicAdd(grad_x + jr3 + idx101, wgt101 * g);
atomicAdd(grad_x + jr3 + idx110, wgt110 * g);
atomicAdd(grad_x + jr3 + idx111, wgt111 * g);
}
}
}
void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
bool training, const float *coords, const float *feat,
int *inds, float *wgts, float *outs) {
trilinear_devoxelize_kernel<<<b, optimal_num_threads(n)>>>(
b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs);
CUDA_CHECK_ERRORS();
}
void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
const float *wgts, const float *grad_y,
float *grad_x) {
trilinear_devoxelize_grad_kernel<<<b, optimal_num_threads(n)>>>(
b, c, n, r3, inds, wgts, grad_y, grad_x);
CUDA_CHECK_ERRORS();
}
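The corner indexing above relies on a bit trick: `x_hi = -1` is an all-ones mask, so `x_hi & r2` adds `r2` exactly when the fractional part is nonzero (likewise `y_hi & r`), which also keeps reads in bounds for points sitting on the upper grid boundary. A plain-PyTorch reference of the forward interpolation, assuming coords lie in `[0, r-1]` (function name is mine):

```
import torch

def trilinear_devoxelize_ref(coords, feat, r):
    # coords: [B, 3, N] float; feat: [B, C, r**3]
    b, _, n = coords.shape
    c = feat.shape[1]
    lo = coords.floor()
    d1 = coords - lo              # fractional offset toward the upper corner
    d0 = 1.0 - d1
    lo = lo.long()
    hi = (d1 > 0).long()          # step to the upper corner only when needed
    out = torch.zeros(b, c, n, dtype=feat.dtype, device=feat.device)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                w = ((d1 if dx else d0)[:, 0] *
                     (d1 if dy else d0)[:, 1] *
                     (d1 if dz else d0)[:, 2])        # [B, N] corner weight
                x = lo[:, 0] + dx * hi[:, 0]
                y = lo[:, 1] + dy * hi[:, 1]
                z = lo[:, 2] + dz * hi[:, 2]
                idx = (x * r + y) * r + z             # flattened voxel index
                out += w.unsqueeze(1) * feat.gather(
                    2, idx.unsqueeze(1).expand(b, c, n))
    return out
```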

@@ -0,0 +1,13 @@
#ifndef _TRILINEAR_DEVOX_CUH
#define _TRILINEAR_DEVOX_CUH
// CUDA function declarations
void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
bool is_training, const float *coords,
const float *feat, int *inds, float *wgts,
float *outs);
void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
const float *wgts, const float *grad_y,
float *grad_x);
#endif

@@ -0,0 +1,16 @@
#ifndef _TRILINEAR_DEVOX_HPP
#define _TRILINEAR_DEVOX_HPP
#include <torch/torch.h>
#include <vector>
std::vector<at::Tensor> trilinear_devoxelize_forward(const int r,
const bool is_training,
const at::Tensor coords,
const at::Tensor features);
at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor weights, const int r);
#endif

@@ -0,0 +1,58 @@
#include "sampling.hpp"
#include "sampling.cuh"
#include "../utils.hpp"
at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) {
CHECK_CUDA(features);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(features);
CHECK_IS_INT(indices);
int b = features.size(0);
int c = features.size(1);
int n = features.size(2);
int m = indices.size(1);
at::Tensor output = torch::zeros(
{b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float));
gather_features(b, c, n, m, features.data_ptr<float>(),
indices.data_ptr<int>(), output.data_ptr<float>());
return output;
}
at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
const int n) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
int b = grad_y.size(0);
int c = grad_y.size(1);
at::Tensor grad_x = torch::zeros(
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr<float>(),
indices.data_ptr<int>(), grad_x.data_ptr<float>());
return grad_x;
}
at::Tensor furthest_point_sampling_forward(at::Tensor coords,
const int num_samples) {
CHECK_CUDA(coords);
CHECK_CONTIGUOUS(coords);
CHECK_IS_FLOAT(coords);
int b = coords.size(0);
int n = coords.size(2);
at::Tensor indices = torch::zeros(
{b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int));
at::Tensor distances = torch::full(
{b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float));
furthest_point_sampling(b, n, num_samples, coords.data_ptr<float>(),
distances.data_ptr<float>(), indices.data_ptr<int>());
return indices;
}
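A hedged usage sketch of the two host wrappers above (the `_backend` module name and import path are assumptions):

```
import torch
from third_party.pvcnn.functional.backend import _backend  # assumed path

coords = torch.rand(2, 3, 4096, device='cuda')   # [B, 3, N]
feats = torch.rand(2, 64, 4096, device='cuda')   # [B, C, N]
idx = _backend.furthest_point_sampling_forward(coords, 1024)  # IntTensor[B, 1024]
centers = _backend.gather_features_forward(feats, idx)        # [B, C, 1024]
```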

@@ -0,0 +1,174 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: gather centers' features (forward)
Args:
b : batch size
c : #channels of features
n : number of points in point clouds
m : number of query/sampled centers
features: points' features, FloatTensor[b, c, n]
indices : centers' indices in points, IntTensor[b, m]
out : gathered features, FloatTensor[b, c, m]
*/
__global__ void gather_features_kernel(int b, int c, int n, int m,
const float *__restrict__ features,
const int *__restrict__ indices,
float *__restrict__ out) {
int batch_index = blockIdx.x;
int channel_index = blockIdx.y;
int temp_index = batch_index * c + channel_index;
features += temp_index * n;
indices += batch_index * m;
out += temp_index * m;
for (int j = threadIdx.x; j < m; j += blockDim.x) {
out[j] = features[indices[j]];
}
}
void gather_features(int b, int c, int n, int m, const float *features,
const int *indices, float *out) {
gather_features_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, features, indices, out);
CUDA_CHECK_ERRORS();
}
/*
Function: gather centers' features (backward)
Args:
b : batch size
c : #channels of features
n : number of points in point clouds
m : number of query/sampled centers
grad_y : grad of gathered features, FloatTensor[b, c, m]
indices : centers' indices in points, IntTensor[b, m]
grad_x : grad of points' features, FloatTensor[b, c, n]
*/
__global__ void gather_features_grad_kernel(int b, int c, int n, int m,
const float *__restrict__ grad_y,
const int *__restrict__ indices,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
int channel_index = blockIdx.y;
int temp_index = batch_index * c + channel_index;
grad_y += temp_index * m;
indices += batch_index * m;
grad_x += temp_index * n;
for (int j = threadIdx.x; j < m; j += blockDim.x) {
atomicAdd(grad_x + indices[j], grad_y[j]);
}
}
void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
const int *indices, float *grad_x) {
gather_features_grad_kernel<<<dim3(b, c, 1), optimal_num_threads(m), 0,
at::cuda::getCurrentCUDAStream()>>>(
b, c, n, m, grad_y, indices, grad_x);
CUDA_CHECK_ERRORS();
}
/*
Function: furthest point sampling
Args:
b : batch size
n : number of points in point clouds
m : number of query/sampled centers
coords : points' coords, FloatTensor[b, 3, n]
distances : minimum squared distance of each point to the sampled set, FloatTensor[b, n]
indices : sampled centers' indices in points, IntTensor[b, m]
*/
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ coords,
float *__restrict__ distances,
int *__restrict__ indices) {
if (m <= 0)
return;
int batch_index = blockIdx.x;
coords += batch_index * n * 3;
distances += batch_index * n;
indices += batch_index * m;
const int BlockSize = 512;
__shared__ float dists[BlockSize];
__shared__ int dists_i[BlockSize];
const int BufferSize = 3072;
__shared__ float buf[BufferSize * 3];
int old = 0;
if (threadIdx.x == 0)
indices[0] = old;
for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) {
buf[j] = coords[j];
buf[j + BufferSize] = coords[j + n];
buf[j + BufferSize + BufferSize] = coords[j + n + n];
}
__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0; // best index
float best = -1; // farthest distance
// calculating the distance with the latest sampled point
float x1 = coords[old];
float y1 = coords[old + n];
float z1 = coords[old + n + n];
for (int k = threadIdx.x; k < n; k += blockDim.x) {
// fetch the current minimum squared distance for point k
float td = distances[k];
float x2, y2, z2;
if (k < BufferSize) {
x2 = buf[k];
y2 = buf[k + BufferSize];
z2 = buf[k + BufferSize + BufferSize];
} else {
x2 = coords[k];
y2 = coords[k + n];
z2 = coords[k + n + n];
}
float d =
(x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, td);
// update "point-to-set" distance
if (d2 != td)
distances[k] = d2;
// update the farthest distance at sample step j
if (d2 > best) {
best = d2;
besti = k;
}
}
dists[threadIdx.x] = best;
dists_i[threadIdx.x] = besti;
for (int u = 0; (1 << u) < blockDim.x; u++) {
__syncthreads();
if (threadIdx.x < (blockDim.x >> (u + 1))) {
int i1 = (threadIdx.x * 2) << u;
int i2 = (threadIdx.x * 2 + 1) << u;
if (dists[i1] < dists[i2]) {
dists[i1] = dists[i2];
dists_i[i1] = dists_i[i2];
}
}
}
__syncthreads();
// finish sample step j; old is the sampled index
old = dists_i[0];
if (threadIdx.x == 0)
indices[j] = old;
}
}
void furthest_point_sampling(int b, int n, int m, const float *coords,
float *distances, int *indices) {
furthest_point_sampling_kernel<<<b, 512>>>(b, n, m, coords, distances,
indices);
CUDA_CHECK_ERRORS();
}
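A compact pure-PyTorch reference of the same sampling loop, handy for verifying the kernel (names are mine; like the kernel, it seeds with point 0 and tracks minimum squared distances to the sampled set):

```
import torch

def furthest_point_sampling_ref(coords, m):
    # coords: [B, 3, N] -> indices [B, m]
    b, _, n = coords.shape
    pts = coords.transpose(1, 2)                          # [B, N, 3]
    dist = torch.full((b, n), 1e38, device=coords.device)
    idx = torch.zeros(b, m, dtype=torch.long, device=coords.device)
    batch = torch.arange(b, device=coords.device)
    for j in range(1, m):
        last = pts[batch, idx[:, j - 1]]                  # latest sampled point
        d = ((pts - last.unsqueeze(1)) ** 2).sum(-1)      # squared distances
        dist = torch.minimum(dist, d)                     # point-to-set distance
        idx[:, j] = dist.argmax(dim=1)                    # farthest point wins
    return idx
```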

@@ -0,0 +1,11 @@
#ifndef _SAMPLING_CUH
#define _SAMPLING_CUH
void gather_features(int b, int c, int n, int m, const float *features,
const int *indices, float *out);
void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
const int *indices, float *grad_x);
void furthest_point_sampling(int b, int n, int m, const float *coords,
float *distances, int *indices);
#endif

@@ -0,0 +1,12 @@
#ifndef _SAMPLING_HPP
#define _SAMPLING_HPP
#include <torch/extension.h>
at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices);
at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
const int n);
at::Tensor furthest_point_sampling_forward(at::Tensor coords,
const int num_samples);
#endif

@@ -0,0 +1,20 @@
#ifndef _UTILS_HPP
#define _UTILS_HPP
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
#x " must be an int tensor")
#define CHECK_IS_FLOAT(x) \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
#x " must be a float tensor")
#endif

@@ -0,0 +1,76 @@
#include "vox.hpp"
#include "vox.cuh"
#include "../utils.hpp"
/*
Function: average pool voxelization (forward)
Args:
features: features, FloatTensor[b, c, n]
coords : coords of each point, IntTensor[b, 3, n]
resolution : voxel resolution
Return:
out : outputs, FloatTensor[b, c, s], s = r ** 3
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
*/
std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
const at::Tensor coords,
const int resolution) {
CHECK_CUDA(features);
CHECK_CUDA(coords);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(coords);
CHECK_IS_FLOAT(features);
CHECK_IS_INT(coords);
int b = features.size(0);
int c = features.size(1);
int n = features.size(2);
int r = resolution;
int r2 = r * r;
int r3 = r2 * r;
at::Tensor ind = torch::zeros(
{b, n}, at::device(features.device()).dtype(at::ScalarType::Int));
at::Tensor out = torch::zeros(
{b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float));
at::Tensor cnt = torch::zeros(
{b, r3}, at::device(features.device()).dtype(at::ScalarType::Int));
avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr<int>(),
features.data_ptr<float>(), ind.data_ptr<int>(),
cnt.data_ptr<int>(), out.data_ptr<float>());
return {out, ind, cnt};
}
/*
Function: average pool voxelization (backward)
Args:
grad_y : grad outputs, FloatTensor[b, c, s]
indices: voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
Return:
grad_x : grad inputs, FloatTensor[b, c, n]
*/
at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor cnt) {
CHECK_CUDA(grad_y);
CHECK_CUDA(indices);
CHECK_CUDA(cnt);
CHECK_CONTIGUOUS(grad_y);
CHECK_CONTIGUOUS(indices);
CHECK_CONTIGUOUS(cnt);
CHECK_IS_FLOAT(grad_y);
CHECK_IS_INT(indices);
CHECK_IS_INT(cnt);
int b = grad_y.size(0);
int c = grad_y.size(1);
int s = grad_y.size(2);
int n = indices.size(1);
at::Tensor grad_x = torch::zeros(
{b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
avg_voxelize_grad(b, c, n, s, indices.data_ptr<int>(), cnt.data_ptr<int>(),
grad_y.data_ptr<float>(), grad_x.data_ptr<float>());
return grad_x;
}

@@ -0,0 +1,126 @@
#include <stdio.h>
#include <stdlib.h>
#include "../cuda_utils.cuh"
/*
Function: get how many points in each voxel grid
Args:
b : batch size
n : number of points
r : voxel resolution
r2 : = r * r
r3 : s, voxel cube size = r ** 3
coords : coords of each point, IntTensor[b, 3, n]
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
*/
__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3,
const int *__restrict__ coords,
int *__restrict__ ind, int *cnt) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
coords += batch_index * n * 3;
ind += batch_index * n;
cnt += batch_index * r3;
for (int i = index; i < n; i += stride) {
// if (ind[i] == -1)
// continue;
ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + n + n];
atomicAdd(cnt + ind[i], 1);
}
}
/*
Function: average pool voxelization (forward)
Args:
b : batch size
c : #channels
n : number of points
s : voxel cube size = voxel resolution ** 3
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
feat: features, FloatTensor[b, c, n]
out : outputs, FloatTensor[b, c, s]
*/
__global__ void avg_voxelize_kernel(int b, int c, int n, int s,
const int *__restrict__ ind,
const int *__restrict__ cnt,
const float *__restrict__ feat,
float *__restrict__ out) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
ind += batch_index * n;
feat += batch_index * c * n;
out += batch_index * c * s;
cnt += batch_index * s;
for (int i = index; i < n; i += stride) {
int pos = ind[i];
// if (pos == -1)
// continue;
int cur_cnt = cnt[pos];
if (cur_cnt > 0) {
float div_cur_cnt = 1.0f / static_cast<float>(cur_cnt);
for (int j = 0; j < c; j++) {
atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt);
}
}
}
}
/*
Function: average pool voxelization (backward)
Args:
b : batch size
c : #channels
n : number of points
r3 : voxel cube size = voxel resolution ** 3
ind : voxel index of each point, IntTensor[b, n]
cnt : #points in each voxel index, IntTensor[b, s]
grad_y : grad outputs, FloatTensor[b, c, s]
grad_x : grad inputs, FloatTensor[b, c, n]
*/
__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3,
const int *__restrict__ ind,
const int *__restrict__ cnt,
const float *__restrict__ grad_y,
float *__restrict__ grad_x) {
int batch_index = blockIdx.x;
int stride = blockDim.x;
int index = threadIdx.x;
ind += batch_index * n;
grad_x += batch_index * c * n;
grad_y += batch_index * c * r3;
cnt += batch_index * r3;
for (int i = index; i < n; i += stride) {
int pos = ind[i];
// if (pos == -1)
// continue;
int cur_cnt = cnt[pos];
if (cur_cnt > 0) {
float div_cur_cnt = 1.0f / static_cast<float>(cur_cnt);
for (int j = 0; j < c; j++) {
atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt);
}
}
}
}
void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
const float *feat, int *ind, int *cnt, float *out) {
grid_stats_kernel<<<b, optimal_num_threads(n)>>>(b, n, r, r2, r3, coords, ind,
cnt);
avg_voxelize_kernel<<<b, optimal_num_threads(n)>>>(b, c, n, r3, ind, cnt,
feat, out);
CUDA_CHECK_ERRORS();
}
void avg_voxelize_grad(int b, int c, int n, int s, const int *ind,
const int *cnt, const float *grad_y, float *grad_x) {
avg_voxelize_grad_kernel<<<b, optimal_num_threads(n)>>>(b, c, n, s, ind, cnt,
grad_y, grad_x);
CUDA_CHECK_ERRORS();
}
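For checking, a pure-PyTorch reference of the two kernels above (names are mine). It scatter-adds features per flattened voxel index and divides by the count afterwards, whereas the kernel divides each contribution before its atomicAdd; the results agree up to floating-point summation order:

```
import torch

def avg_voxelize_ref(feat, coords, r):
    # feat: [B, C, N] float; coords: [B, 3, N] int in [0, r-1]
    b, c, n = feat.shape
    ind = ((coords[:, 0] * r + coords[:, 1]) * r + coords[:, 2]).long()  # [B, N]
    cnt = torch.zeros(b, r ** 3, dtype=feat.dtype, device=feat.device)
    cnt.scatter_add_(1, ind, torch.ones_like(ind, dtype=feat.dtype))
    out = torch.zeros(b, c, r ** 3, dtype=feat.dtype, device=feat.device)
    out.scatter_add_(2, ind.unsqueeze(1).expand(b, c, n), feat)
    out = out / cnt.clamp(min=1).unsqueeze(1)   # average; empty voxels stay 0
    return out, ind, cnt                        # cnt kept as float here
```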

@@ -0,0 +1,10 @@
#ifndef _VOX_CUH
#define _VOX_CUH
// CUDA function declarations
void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
const float *feat, int *ind, int *cnt, float *out);
void avg_voxelize_grad(int b, int c, int n, int s, const int *idx,
const int *cnt, const float *grad_y, float *grad_x);
#endif

@@ -0,0 +1,15 @@
#ifndef _VOX_HPP
#define _VOX_HPP
#include <torch/torch.h>
#include <vector>
std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
const at::Tensor coords,
const int resolution);
at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
const at::Tensor indices,
const at::Tensor cnt);
#endif

@@ -0,0 +1,47 @@
from torch.autograd import Function
import torch
# from modules.functional.backend import _backend
from third_party.pvcnn.functional.backend import _backend
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
__all__ = ['avg_voxelize']
class AvgVoxelization(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, features, coords, resolution):
"""
:param ctx:
:param features: Features of the point cloud, FloatTensor[B, C, N]
:param coords: Voxel coordinates of each point, IntTensor[B, 3, N]
:param resolution: Voxel resolution
:return:
Voxelized Features, FloatTensor[B, C, R, R, R]
"""
features = features.contiguous()
coords = coords.int()[:,:3].contiguous()
b, c, _ = features.shape
out, indices, counts = _backend.avg_voxelize_forward(
features, coords, resolution)
ctx.save_for_backward(indices, counts)
return out.view(b, c, resolution, resolution, resolution)
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
"""
:param ctx:
:param grad_output: gradient of output, FloatTensor[B, C, R, R, R]
:return:
gradient of inputs, FloatTensor[B, C, N]
"""
b, c = grad_output.shape[:2]
indices, counts = ctx.saved_tensors
grad_features = _backend.avg_voxelize_backward(
grad_output.contiguous().view(b, c, -1), indices, counts)
return grad_features, None, None
avg_voxelize = AvgVoxelization.apply
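Hedged usage sketch of the `avg_voxelize` Function defined above (shapes follow its docstring; the tensors are arbitrary):

```
import torch

points = torch.rand(4, 64, 2048, device='cuda', requires_grad=True)  # [B, C, N]
vcoords = (torch.rand(4, 3, 2048, device='cuda') * 31).int()         # [B, 3, N]
voxels = avg_voxelize(points, vcoords, 32)   # [B, C, 32, 32, 32]
voxels.sum().backward()                      # gradients flow back to `points`
```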

third_party/torchdiffeq/LICENSE vendored Normal file
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018 Ricky Tian Qi Chen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

third_party/torchdiffeq/README.md vendored Normal file
@@ -0,0 +1 @@
adapted from `https://github.com/rtqichen/torchdiffeq/tree/master/torchdiffeq`

@@ -0,0 +1,4 @@
from ._impl import odeint
from ._impl import odeint_adjoint
from ._impl import odeint_event
__version__ = "0.2.2"

@@ -0,0 +1,2 @@
from .odeint import odeint, odeint_event
from .adjoint import odeint_adjoint

@@ -0,0 +1,25 @@
import torch
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
_ADAPTIVE_HEUN_TABLEAU = _ButcherTableau(
alpha=torch.tensor([1.], dtype=torch.float64),
beta=[
torch.tensor([1.], dtype=torch.float64),
],
c_sol=torch.tensor([0.5, 0.5], dtype=torch.float64),
c_error=torch.tensor([
0.5,
-0.5,
], dtype=torch.float64),
)
_AH_C_MID = torch.tensor([
0.5, 0.
], dtype=torch.float64)
class AdaptiveHeunSolver(RKAdaptiveStepsizeODESolver):
order = 2
tableau = _ADAPTIVE_HEUN_TABLEAU
mid = _AH_C_MID
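A minimal usage sketch: this solver is selected by name through `odeint` (the `'adaptive_heun'` method string follows upstream torchdiffeq, and the import path assumes the vendored package location in this diff):

```
import torch
from third_party.torchdiffeq import odeint

def f(t, y):                 # dy/dt = -y, so y(t) = exp(-t)
    return -y

y0 = torch.tensor([1.0])
t = torch.linspace(0., 2., 10)
ys = odeint(f, y0, t, method='adaptive_heun')   # ys[i] approximates exp(-t[i])
```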

@@ -0,0 +1,280 @@
import warnings
import torch
import torch.nn as nn
from .odeint import SOLVERS, odeint
from .misc import _check_inputs, _flat_to_shape
from .misc import _mixed_norm
class OdeintAdjointMethod(torch.autograd.Function):
@staticmethod
def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol, adjoint_method,
adjoint_options, t_requires_grad, *adjoint_params):
ctx.shapes = shapes
ctx.func = func
ctx.adjoint_rtol = adjoint_rtol
ctx.adjoint_atol = adjoint_atol
ctx.adjoint_method = adjoint_method
ctx.adjoint_options = adjoint_options
ctx.t_requires_grad = t_requires_grad
ctx.event_mode = event_fn is not None
with torch.no_grad():
ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
if event_fn is None:
y = ans
else:
event_t, y = ans
ctx.event_t = event_t
ctx.save_for_backward(t, y, *adjoint_params)
return ans
@staticmethod
def backward(ctx, *grad_y):
with torch.no_grad():
func = ctx.func
adjoint_rtol = ctx.adjoint_rtol
adjoint_atol = ctx.adjoint_atol
adjoint_method = ctx.adjoint_method
adjoint_options = ctx.adjoint_options
t_requires_grad = ctx.t_requires_grad
t, y, *adjoint_params = ctx.saved_tensors
adjoint_params = tuple(adjoint_params)
# Backprop as if integrating up to event time.
# Does NOT backpropagate through the event time.
event_mode = ctx.event_mode
if event_mode:
event_t = ctx.event_t
_t = t
t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
grad_y = grad_y[1]
else:
grad_y = grad_y[0]
##################################
# Set up initial state #
##################################
# [-1] because y and grad_y are both of shape (len(t), *y0.shape)
aug_state = [torch.zeros((), dtype=y.dtype, device=y.device), y[-1], grad_y[-1]] # vjp_t, y, vjp_y
aug_state.extend([torch.zeros_like(param) for param in adjoint_params]) # vjp_params
##################################
# Set up backward ODE func #
##################################
# TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
def augmented_dynamics(t, y_aug):
# Dynamics of the original system augmented with
# the adjoint wrt y, and an integrator wrt t and args.
y = y_aug[1]
adj_y = y_aug[2]
# ignore gradients wrt time and parameters
with torch.enable_grad():
t_ = t.detach()
t = t_.requires_grad_(True)
y = y.detach().requires_grad_(True)
# If using an adaptive solver we don't want to waste time resolving dL/dt unless we need it (which
# doesn't necessarily even exist if there is piecewise structure in time), so turning off gradients
# wrt t here means we won't compute that if we don't need it.
func_eval = func(t if t_requires_grad else t_, y)
# Workaround for PyTorch bug #39784
_t = torch.as_strided(t, (), ()) # noqa
_y = torch.as_strided(y, (), ()) # noqa
_params = tuple(torch.as_strided(param, (), ()) for param in adjoint_params) # noqa
vjp_t, vjp_y, *vjp_params = torch.autograd.grad(
func_eval, (t, y) + adjoint_params, -adj_y,
allow_unused=True, retain_graph=True
)
# autograd.grad returns None if no gradient, set to zero.
vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
vjp_y = torch.zeros_like(y) if vjp_y is None else vjp_y
vjp_params = [torch.zeros_like(param) if vjp_param is None else vjp_param
for param, vjp_param in zip(adjoint_params, vjp_params)]
return (vjp_t, func_eval, vjp_y, *vjp_params)
##################################
# Solve adjoint ODE #
##################################
if t_requires_grad:
time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
else:
time_vjps = None
for i in range(len(t) - 1, 0, -1):
if t_requires_grad:
# Compute the effect of moving the current time measurement point.
# We don't compute this unless we need to, to save some computation.
func_eval = func(t[i], y[i])
dLd_cur_t = func_eval.reshape(-1).dot(grad_y[i].reshape(-1))
aug_state[0] -= dLd_cur_t
time_vjps[i] = dLd_cur_t
# Run the augmented system backwards in time.
aug_state = odeint(
augmented_dynamics, tuple(aug_state),
t[i - 1:i + 1].flip(0),
rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
)
aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
aug_state[1] = y[i - 1] # update to use our forward-pass estimate of the state
aug_state[2] += grad_y[i - 1] # update any gradients wrt state at this time point
if t_requires_grad:
time_vjps[0] = aug_state[0]
# Only compute gradient wrt initial time when in event handling mode.
if event_mode and t_requires_grad:
time_vjps = torch.cat([time_vjps[0].reshape(-1), torch.zeros_like(_t[1:])])
adj_y = aug_state[2]
adj_params = aug_state[3:]
return (None, None, adj_y, time_vjps, None, None, None, None, None, None, None, None, None, None, *adj_params)
def odeint_adjoint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None,
adjoint_rtol=None, adjoint_atol=None, adjoint_method=None, adjoint_options=None, adjoint_params=None):
# We need this in order to access the variables inside this module,
# since we have no other way of getting variables along the execution path.
if adjoint_params is None and not isinstance(func, nn.Module):
raise ValueError('func must be an instance of nn.Module to specify the adjoint parameters; alternatively they '
'can be specified explicitly via the `adjoint_params` argument. If there are no parameters '
'then it is allowable to set `adjoint_params=()`.')
# Must come before _check_inputs as we don't want to use normalised input (in particular any changes to options)
if adjoint_rtol is None:
adjoint_rtol = rtol
if adjoint_atol is None:
adjoint_atol = atol
if adjoint_method is None:
adjoint_method = method
if adjoint_method != method and options is not None and adjoint_options is None:
raise ValueError("If `adjoint_method != method` then we cannot infer `adjoint_options` from `options`. So as "
"`options` has been passed then `adjoint_options` must be passed as well.")
if adjoint_options is None:
adjoint_options = {k: v for k, v in options.items() if k != "norm"} if options is not None else {}
else:
# Avoid in-place modifying a user-specified dict.
adjoint_options = adjoint_options.copy()
if adjoint_params is None:
adjoint_params = tuple(find_parameters(func))
else:
adjoint_params = tuple(adjoint_params) # in case adjoint_params is a generator.
# Filter params that don't require gradients.
oldlen_ = len(adjoint_params)
adjoint_params = tuple(p for p in adjoint_params if p.requires_grad)
if len(adjoint_params) != oldlen_:
# Some params were excluded.
# Issue a warning if a user-specified norm is specified.
if 'norm' in adjoint_options and callable(adjoint_options['norm']):
warnings.warn("An adjoint parameter was passed without requiring gradient. For efficiency this will be "
"excluded from the adjoint pass, and will not appear as a tensor in the adjoint norm.")
# Convert to flattened state.
shapes, func, y0, t, rtol, atol, method, options, event_fn, decreasing_time = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
# Handle the adjoint norm function.
state_norm = options["norm"]
handle_adjoint_norm_(adjoint_options, shapes, state_norm)
ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
adjoint_method, adjoint_options, t.requires_grad, *adjoint_params)
if event_fn is None:
solution = ans
else:
event_t, solution = ans
event_t = event_t.to(t)
if decreasing_time:
event_t = -event_t
if shapes is not None:
solution = _flat_to_shape(solution, (len(t),), shapes)
if event_fn is None:
return solution
else:
return event_t, solution
def find_parameters(module):
assert isinstance(module, nn.Module)
# If called within DataParallel, parameters won't appear in module.parameters().
if getattr(module, '_is_replica', False):
def find_tensor_attributes(module):
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v) and v.requires_grad]
return tuples
gen = module._named_members(get_members_fn=find_tensor_attributes)
return [param for _, param in gen]
else:
return list(module.parameters())
def handle_adjoint_norm_(adjoint_options, shapes, state_norm):
"""In-place modifies the adjoint options to choose or wrap the norm function."""
# This is the default adjoint norm on the backward pass: a mixed norm over the tuple of inputs.
def default_adjoint_norm(tensor_tuple):
t, y, adj_y, *adj_params = tensor_tuple
# (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
return max(t.abs(), state_norm(y), state_norm(adj_y), _mixed_norm(adj_params))
if "norm" not in adjoint_options:
# `adjoint_options` was not explicitly specified by the user. Use the default norm.
adjoint_options["norm"] = default_adjoint_norm
else:
# `adjoint_options` was explicitly specified by the user...
try:
adjoint_norm = adjoint_options['norm']
except KeyError:
# ...but they did not specify the norm argument. Back to plan A: use the default norm.
adjoint_options['norm'] = default_adjoint_norm
else:
# ...and they did specify the norm argument.
if adjoint_norm == 'seminorm':
# They told us they want to use seminorms. Slight modification to plan A: use the default norm,
# but ignore the parameter state
def adjoint_seminorm(tensor_tuple):
t, y, adj_y, *adj_params = tensor_tuple
# (If the state is actually a flattened tuple then this will be unpacked again in state_norm.)
return max(t.abs(), state_norm(y), state_norm(adj_y))
adjoint_options['norm'] = adjoint_seminorm
else:
# And they're using their own custom norm.
if shapes is None:
# The state on the forward pass was a tensor, not a tuple. We don't need to do anything, they're
# already going to get given the full adjoint state as (t, y, adj_y, adj_params)
pass # this branch included for clarity
else:
# This is the bit that is tuple/tensor abstraction-breaking, because the odeint machinery
# doesn't know about the tupled nature of the forward state. We need to tell the user's adjoint
# norm about that ourselves.
def _adjoint_norm(tensor_tuple):
t, y, adj_y, *adj_params = tensor_tuple
y = _flat_to_shape(y, (), shapes)
adj_y = _flat_to_shape(adj_y, (), shapes)
return adjoint_norm((t, *y, *adj_y, *adj_params))
adjoint_options['norm'] = _adjoint_norm
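A minimal usage sketch of `odeint_adjoint` as defined above; `func` must be an `nn.Module` (or `adjoint_params` must be passed explicitly) so the adjoint pass can find its parameters:

```
import torch
import torch.nn as nn
from third_party.torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, y):
        return self.net(y)

func = ODEFunc()
y0 = torch.rand(8, 2)
t = torch.linspace(0., 1., 5)
ys = odeint_adjoint(func, y0, t)   # [len(t), 8, 2]
ys[-1].sum().backward()            # memory-efficient grads via the adjoint ODE
```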

@@ -0,0 +1,22 @@
import torch
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
_BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau(
alpha=torch.tensor([1 / 2, 3 / 4, 1.], dtype=torch.float64),
beta=[
torch.tensor([1 / 2], dtype=torch.float64),
torch.tensor([0., 3 / 4], dtype=torch.float64),
torch.tensor([2 / 9, 1 / 3, 4 / 9], dtype=torch.float64)
],
c_sol=torch.tensor([2 / 9, 1 / 3, 4 / 9, 0.], dtype=torch.float64),
c_error=torch.tensor([2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8], dtype=torch.float64),
)
_BS_C_MID = torch.tensor([0., 0.5, 0., 0.], dtype=torch.float64)
class Bosh3Solver(RKAdaptiveStepsizeODESolver):
order = 3
tableau = _BOGACKI_SHAMPINE_TABLEAU
mid = _BS_C_MID

Some files were not shown because too many files have changed in this diff.