From 03eb2e48fd436e54edb5e2a8ec494884edca4163 Mon Sep 17 00:00:00 2001 From: Xu Ma Date: Mon, 4 Oct 2021 03:25:18 -0400 Subject: [PATCH] update --- pointnet2_ops_lib/MANIFEST.in | 1 + pointnet2_ops_lib/pointnet2_ops/__init__.py | 3 + .../_ext-src/include/ball_query.h | 5 + .../_ext-src/include/cuda_utils.h | 41 ++ .../_ext-src/include/group_points.h | 5 + .../_ext-src/include/interpolate.h | 10 + .../pointnet2_ops/_ext-src/include/sampling.h | 6 + .../pointnet2_ops/_ext-src/include/utils.h | 25 ++ .../pointnet2_ops/_ext-src/src/ball_query.cpp | 32 ++ .../_ext-src/src/ball_query_gpu.cu | 54 +++ .../pointnet2_ops/_ext-src/src/bindings.cpp | 19 + .../_ext-src/src/group_points.cpp | 62 +++ .../_ext-src/src/group_points_gpu.cu | 75 ++++ .../_ext-src/src/interpolate.cpp | 99 +++++ .../_ext-src/src/interpolate_gpu.cu | 154 +++++++ .../pointnet2_ops/_ext-src/src/sampling.cpp | 87 ++++ .../_ext-src/src/sampling_gpu.cu | 229 +++++++++++ pointnet2_ops_lib/pointnet2_ops/_version.py | 1 + .../pointnet2_ops/pointnet2_modules.py | 209 ++++++++++ .../pointnet2_ops/pointnet2_utils.py | 379 ++++++++++++++++++ pointnet2_ops_lib/setup.py | 39 ++ 21 files changed, 1535 insertions(+) create mode 100644 pointnet2_ops_lib/MANIFEST.in create mode 100644 pointnet2_ops_lib/pointnet2_ops/__init__.py create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp create mode 100644 pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu create mode 100644 pointnet2_ops_lib/pointnet2_ops/_version.py create mode 100644 pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py create mode 100644 pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py create mode 100644 pointnet2_ops_lib/setup.py diff --git a/pointnet2_ops_lib/MANIFEST.in b/pointnet2_ops_lib/MANIFEST.in new file mode 100644 index 0000000..a4eb5de --- /dev/null +++ b/pointnet2_ops_lib/MANIFEST.in @@ -0,0 +1 @@ +graft pointnet2_ops/_ext-src diff --git a/pointnet2_ops_lib/pointnet2_ops/__init__.py b/pointnet2_ops_lib/pointnet2_ops/__init__.py new file mode 100644 index 0000000..5fd361f --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/__init__.py @@ -0,0 +1,3 @@ +import pointnet2_ops.pointnet2_modules +import pointnet2_ops.pointnet2_utils +from pointnet2_ops._version import __version__ diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h new file mode 100644 index 0000000..1bbc638 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h @@ -0,0 +1,5 @@ +#pragma once +#include + +at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, + const int nsample); diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h new file mode 100644 index 0000000..0fd5b6e --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h @@ -0,0 +1,41 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include +#include + +#include +#include + +#include + +#define TOTAL_THREADS 512 + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = + max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } while (0) + +#endif diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h new file mode 100644 index 0000000..ad20cda --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h @@ -0,0 +1,5 @@ +#pragma once +#include + +at::Tensor group_points(at::Tensor points, at::Tensor idx); +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h new file mode 100644 index 0000000..26b3464 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +std::vector three_nn(at::Tensor unknowns, at::Tensor knows); +at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, + at::Tensor weight); +at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, + at::Tensor weight, const int m); diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h new file mode 100644 index 0000000..d795271 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h @@ -0,0 +1,6 @@ +#pragma once +#include + +at::Tensor gather_points(at::Tensor points, at::Tensor idx); +at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); +at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h new file mode 100644 index 0000000..5f080ed --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h @@ -0,0 +1,25 @@ +#pragma once +#include +#include + +#define CHECK_CUDA(x) \ + do { \ + AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ + } while (0) + +#define CHECK_CONTIGUOUS(x) \ + do { \ + AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ + } while (0) + +#define CHECK_IS_INT(x) \ + do { \ + AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor"); \ + } while (0) + +#define CHECK_IS_FLOAT(x) \ + do { \ + AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor"); \ + } while (0) diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp new file mode 100644 index 0000000..b1797c1 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp @@ -0,0 +1,32 @@ +#include "ball_query.h" +#include "utils.h" + +void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, + int nsample, const float *new_xyz, + const float *xyz, int *idx); + +at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, + const int nsample) { + CHECK_CONTIGUOUS(new_xyz); + CHECK_CONTIGUOUS(xyz); + CHECK_IS_FLOAT(new_xyz); + CHECK_IS_FLOAT(xyz); + + if (new_xyz.is_cuda()) { + CHECK_CUDA(xyz); + } + + at::Tensor idx = + torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, + at::device(new_xyz.device()).dtype(at::ScalarType::Int)); + + if (new_xyz.is_cuda()) { + query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), + radius, nsample, new_xyz.data_ptr(), + xyz.data_ptr(), idx.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return idx; +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu new file mode 100644 index 0000000..559aef9 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu @@ -0,0 +1,54 @@ +#include +#include +#include + +#include "cuda_utils.h" + +// input: new_xyz(b, m, 3) xyz(b, n, 3) +// output: idx(b, m, nsample) +__global__ void query_ball_point_kernel(int b, int n, int m, float radius, + int nsample, + const float *__restrict__ new_xyz, + const float *__restrict__ xyz, + int *__restrict__ idx) { + int batch_index = blockIdx.x; + xyz += batch_index * n * 3; + new_xyz += batch_index * m * 3; + idx += m * nsample * batch_index; + + int index = threadIdx.x; + int stride = blockDim.x; + + float radius2 = radius * radius; + for (int j = index; j < m; j += stride) { + float new_x = new_xyz[j * 3 + 0]; + float new_y = new_xyz[j * 3 + 1]; + float new_z = new_xyz[j * 3 + 2]; + for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + + (new_z - z) * (new_z - z); + if (d2 < radius2) { + if (cnt == 0) { + for (int l = 0; l < nsample; ++l) { + idx[j * nsample + l] = k; + } + } + idx[j * nsample + cnt] = k; + ++cnt; + } + } + } +} + +void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, + int nsample, const float *new_xyz, + const float *xyz, int *idx) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + query_ball_point_kernel<<>>( + b, n, m, radius, nsample, new_xyz, xyz, idx); + + CUDA_CHECK_ERRORS(); +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp new file mode 100644 index 0000000..d1916ce --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp @@ -0,0 +1,19 @@ +#include "ball_query.h" +#include "group_points.h" +#include "interpolate.h" +#include "sampling.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gather_points", &gather_points); + m.def("gather_points_grad", &gather_points_grad); + m.def("furthest_point_sampling", &furthest_point_sampling); + + m.def("three_nn", &three_nn); + m.def("three_interpolate", &three_interpolate); + m.def("three_interpolate_grad", &three_interpolate_grad); + + m.def("ball_query", &ball_query); + + m.def("group_points", &group_points); + m.def("group_points_grad", &group_points_grad); +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp new file mode 100644 index 0000000..285a4bd --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp @@ -0,0 +1,62 @@ +#include "group_points.h" +#include "utils.h" + +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, + float *out); + +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + int nsample, const float *grad_out, + const int *idx, float *grad_points); + +at::Tensor group_points(at::Tensor points, at::Tensor idx) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), + idx.size(1), idx.size(2), + points.data_ptr(), idx.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} + +at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), n}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + group_points_grad_kernel_wrapper( + grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), + grad_out.data_ptr(), idx.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu new file mode 100644 index 0000000..57c2b1b --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu @@ -0,0 +1,75 @@ +#include +#include + +#include "cuda_utils.h" + +// input: points(b, c, n) idx(b, npoints, nsample) +// output: out(b, c, npoints, nsample) +__global__ void group_points_kernel(int b, int c, int n, int npoints, + int nsample, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + int batch_index = blockIdx.x; + points += batch_index * n * c; + idx += batch_index * npoints * nsample; + out += batch_index * npoints * nsample * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; + for (int k = 0; k < nsample; ++k) { + int ii = idx[j * nsample + k]; + out[(l * npoints + j) * nsample + k] = points[l * n + ii]; + } + } +} + +void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, + float *out) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + group_points_kernel<<>>( + b, c, n, npoints, nsample, points, idx, out); + + CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) +// output: grad_points(b, c, n) +__global__ void group_points_grad_kernel(int b, int c, int n, int npoints, + int nsample, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + int batch_index = blockIdx.x; + grad_out += batch_index * npoints * nsample * c; + idx += batch_index * npoints * nsample; + grad_points += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * npoints; i += stride) { + const int l = i / npoints; + const int j = i % npoints; + for (int k = 0; k < nsample; ++k) { + int ii = idx[j * nsample + k]; + atomicAdd(grad_points + l * n + ii, + grad_out[(l * npoints + j) * nsample + k]); + } + } +} + +void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + int nsample, const float *grad_out, + const int *idx, float *grad_points) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + group_points_grad_kernel<<>>( + b, c, n, npoints, nsample, grad_out, idx, grad_points); + + CUDA_CHECK_ERRORS(); +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp new file mode 100644 index 0000000..cdee31c --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp @@ -0,0 +1,99 @@ +#include "interpolate.h" +#include "utils.h" + +void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx); +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, + const float *points, const int *idx, + const float *weight, float *out); +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, + const float *grad_out, + const int *idx, const float *weight, + float *grad_points); + +std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { + CHECK_CONTIGUOUS(unknowns); + CHECK_CONTIGUOUS(knows); + CHECK_IS_FLOAT(unknowns); + CHECK_IS_FLOAT(knows); + + if (unknowns.is_cuda()) { + CHECK_CUDA(knows); + } + + at::Tensor idx = + torch::zeros({unknowns.size(0), unknowns.size(1), 3}, + at::device(unknowns.device()).dtype(at::ScalarType::Int)); + at::Tensor dist2 = + torch::zeros({unknowns.size(0), unknowns.size(1), 3}, + at::device(unknowns.device()).dtype(at::ScalarType::Float)); + + if (unknowns.is_cuda()) { + three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), + unknowns.data_ptr(), knows.data_ptr(), + dist2.data_ptr(), idx.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return {dist2, idx}; +} + +at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, + at::Tensor weight) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_CONTIGUOUS(weight); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + CHECK_IS_FLOAT(weight); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + CHECK_CUDA(weight); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + three_interpolate_kernel_wrapper( + points.size(0), points.size(1), points.size(2), idx.size(1), + points.data_ptr(), idx.data_ptr(), weight.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} +at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, + at::Tensor weight, const int m) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_CONTIGUOUS(weight); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + CHECK_IS_FLOAT(weight); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + CHECK_CUDA(weight); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), m}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + three_interpolate_grad_kernel_wrapper( + grad_out.size(0), grad_out.size(1), grad_out.size(2), m, + grad_out.data_ptr(), idx.data_ptr(), + weight.data_ptr(), output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu new file mode 100644 index 0000000..81c5548 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu @@ -0,0 +1,154 @@ +#include +#include +#include + +#include "cuda_utils.h" + +// input: unknown(b, n, 3) known(b, m, 3) +// output: dist2(b, n, 3), idx(b, n, 3) +__global__ void three_nn_kernel(int b, int n, int m, + const float *__restrict__ unknown, + const float *__restrict__ known, + float *__restrict__ dist2, + int *__restrict__ idx) { + int batch_index = blockIdx.x; + unknown += batch_index * n * 3; + known += batch_index * m * 3; + dist2 += batch_index * n * 3; + idx += batch_index * n * 3; + + int index = threadIdx.x; + int stride = blockDim.x; + for (int j = index; j < n; j += stride) { + float ux = unknown[j * 3 + 0]; + float uy = unknown[j * 3 + 1]; + float uz = unknown[j * 3 + 2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; + besti3 = besti2; + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + } else if (d < best2) { + best3 = best2; + besti3 = besti2; + best2 = d; + besti2 = k; + } else if (d < best3) { + best3 = d; + besti3 = k; + } + } + dist2[j * 3 + 0] = best1; + dist2[j * 3 + 1] = best2; + dist2[j * 3 + 2] = best3; + + idx[j * 3 + 0] = besti1; + idx[j * 3 + 1] = besti2; + idx[j * 3 + 2] = besti3; + } +} + +void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_nn_kernel<<>>(b, n, m, unknown, known, + dist2, idx); + + CUDA_CHECK_ERRORS(); +} + +// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) +// output: out(b, c, n) +__global__ void three_interpolate_kernel(int b, int c, int m, int n, + const float *__restrict__ points, + const int *__restrict__ idx, + const float *__restrict__ weight, + float *__restrict__ out) { + int batch_index = blockIdx.x; + points += batch_index * m * c; + + idx += batch_index * n * 3; + weight += batch_index * n * 3; + + out += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weight[j * 3 + 0]; + float w2 = weight[j * 3 + 1]; + float w3 = weight[j * 3 + 2]; + + int i1 = idx[j * 3 + 0]; + int i2 = idx[j * 3 + 1]; + int i3 = idx[j * 3 + 2]; + + out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + + points[l * m + i3] * w3; + } +} + +void three_interpolate_kernel_wrapper(int b, int c, int m, int n, + const float *points, const int *idx, + const float *weight, float *out) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_kernel<<>>( + b, c, m, n, points, idx, weight, out); + + CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) +// output: grad_points(b, c, m) + +__global__ void three_interpolate_grad_kernel( + int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, const float *__restrict__ weight, + float *__restrict__ grad_points) { + int batch_index = blockIdx.x; + grad_out += batch_index * n * c; + idx += batch_index * n * 3; + weight += batch_index * n * 3; + grad_points += batch_index * m * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weight[j * 3 + 0]; + float w2 = weight[j * 3 + 1]; + float w3 = weight[j * 3 + 2]; + + int i1 = idx[j * 3 + 0]; + int i2 = idx[j * 3 + 1]; + int i3 = idx[j * 3 + 2]; + + atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); + atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); + atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); + } +} + +void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, + const float *grad_out, + const int *idx, const float *weight, + float *grad_points) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_grad_kernel<<>>( + b, c, n, m, grad_out, idx, weight, grad_points); + + CUDA_CHECK_ERRORS(); +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp new file mode 100644 index 0000000..ddbdc11 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp @@ -0,0 +1,87 @@ +#include "sampling.h" +#include "utils.h" + +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, + const float *points, const int *idx, + float *out); +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points); + +void furthest_point_sampling_kernel_wrapper(int b, int n, int m, + const float *dataset, float *temp, + int *idxs); + +at::Tensor gather_points(at::Tensor points, at::Tensor idx) { + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(points); + CHECK_IS_INT(idx); + + if (points.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({points.size(0), points.size(1), idx.size(1)}, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), + idx.size(1), points.data_ptr(), + idx.data_ptr(), output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} + +at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, + const int n) { + CHECK_CONTIGUOUS(grad_out); + CHECK_CONTIGUOUS(idx); + CHECK_IS_FLOAT(grad_out); + CHECK_IS_INT(idx); + + if (grad_out.is_cuda()) { + CHECK_CUDA(idx); + } + + at::Tensor output = + torch::zeros({grad_out.size(0), grad_out.size(1), n}, + at::device(grad_out.device()).dtype(at::ScalarType::Float)); + + if (grad_out.is_cuda()) { + gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, + idx.size(1), grad_out.data_ptr(), + idx.data_ptr(), + output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} +at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(points); + + at::Tensor output = + torch::zeros({points.size(0), nsamples}, + at::device(points.device()).dtype(at::ScalarType::Int)); + + at::Tensor tmp = + torch::full({points.size(0), points.size(1)}, 1e10, + at::device(points.device()).dtype(at::ScalarType::Float)); + + if (points.is_cuda()) { + furthest_point_sampling_kernel_wrapper( + points.size(0), points.size(1), nsamples, points.data_ptr(), + tmp.data_ptr(), output.data_ptr()); + } else { + AT_ASSERT(false, "CPU not supported"); + } + + return output; +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu new file mode 100644 index 0000000..fc573f0 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu @@ -0,0 +1,229 @@ +#include +#include + +#include "cuda_utils.h" + +// input: points(b, c, n) idx(b, m) +// output: out(b, c, m) +__global__ void gather_points_kernel(int b, int c, int n, int m, + const float *__restrict__ points, + const int *__restrict__ idx, + float *__restrict__ out) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; + } + } + } +} + +void gather_points_kernel_wrapper(int b, int c, int n, int npoints, + const float *points, const int *idx, + float *out) { + gather_points_kernel<<>>(b, c, n, npoints, + points, idx, out); + + CUDA_CHECK_ERRORS(); +} + +// input: grad_out(b, c, m) idx(b, m) +// output: grad_points(b, c, n) +__global__ void gather_points_grad_kernel(int b, int c, int n, int m, + const float *__restrict__ grad_out, + const int *__restrict__ idx, + float *__restrict__ grad_points) { + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int l = blockIdx.y; l < c; l += gridDim.y) { + for (int j = threadIdx.x; j < m; j += blockDim.x) { + int a = idx[i * m + j]; + atomicAdd(grad_points + (i * c + l) * n + a, + grad_out[(i * c + l) * m + j]); + } + } + } +} + +void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, + float *grad_points) { + gather_points_grad_kernel<<>>( + b, c, n, npoints, grad_out, idx, grad_points); + + CUDA_CHECK_ERRORS(); +} + +__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, + int idx1, int idx2) { + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? i2 : i1; +} + +// Input dataset: (b, n, 3), tmp: (b, n) +// Ouput idxs (b, m) +template +__global__ void furthest_point_sampling_kernel( + int b, int n, int m, const float *__restrict__ dataset, + float *__restrict__ temp, int *__restrict__ idxs) { + if (m <= 0) return; + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int batch_index = blockIdx.x; + dataset += batch_index * n * 3; + temp += batch_index * n; + idxs += batch_index * m; + + int tid = threadIdx.x; + const int stride = block_size; + + int old = 0; + if (threadIdx.x == 0) idxs[0] = old; + + __syncthreads(); + for (int j = 1; j < m; j++) { + int besti = 0; + float best = -1; + float x1 = dataset[old * 3 + 0]; + float y1 = dataset[old * 3 + 1]; + float z1 = dataset[old * 3 + 2]; + for (int k = tid; k < n; k += stride) { + float x2, y2, z2; + x2 = dataset[k * 3 + 0]; + y2 = dataset[k * 3 + 1]; + z2 = dataset[k * 3 + 2]; + float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + if (mag <= 1e-3) continue; + + float d = + (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + + float d2 = min(d, temp[k]); + temp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) idxs[j] = old; + } +} + +void furthest_point_sampling_kernel_wrapper(int b, int n, int m, + const float *dataset, float *temp, + int *idxs) { + unsigned int n_threads = opt_n_threads(n); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (n_threads) { + case 512: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 256: + furthest_point_sampling_kernel<256> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 128: + furthest_point_sampling_kernel<128> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 64: + furthest_point_sampling_kernel<64> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 32: + furthest_point_sampling_kernel<32> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 16: + furthest_point_sampling_kernel<16> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 8: + furthest_point_sampling_kernel<8> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 4: + furthest_point_sampling_kernel<4> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 2: + furthest_point_sampling_kernel<2> + <<>>(b, n, m, dataset, temp, idxs); + break; + case 1: + furthest_point_sampling_kernel<1> + <<>>(b, n, m, dataset, temp, idxs); + break; + default: + furthest_point_sampling_kernel<512> + <<>>(b, n, m, dataset, temp, idxs); + } + + CUDA_CHECK_ERRORS(); +} diff --git a/pointnet2_ops_lib/pointnet2_ops/_version.py b/pointnet2_ops_lib/pointnet2_ops/_version.py new file mode 100644 index 0000000..528787c --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/_version.py @@ -0,0 +1 @@ +__version__ = "3.0.0" diff --git a/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py b/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py new file mode 100644 index 0000000..a0ad4f6 --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py @@ -0,0 +1,209 @@ +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pointnet2_ops import pointnet2_utils + + +def build_shared_mlp(mlp_spec: List[int], bn: bool = True): + layers = [] + for i in range(1, len(mlp_spec)): + layers.append( + nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) + ) + if bn: + layers.append(nn.BatchNorm2d(mlp_spec[i])) + layers.append(nn.ReLU(True)) + + return nn.Sequential(*layers) + + +class _PointnetSAModuleBase(nn.Module): + def __init__(self): + super(_PointnetSAModuleBase, self).__init__() + self.npoint = None + self.groupers = None + self.mlps = None + + def forward( + self, xyz: torch.Tensor, features: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor of the xyz coordinates of the features + features : torch.Tensor + (B, C, N) tensor of the descriptors of the the features + + Returns + ------- + new_xyz : torch.Tensor + (B, npoint, 3) tensor of the new features' xyz + new_features : torch.Tensor + (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors + """ + + new_features_list = [] + + xyz_flipped = xyz.transpose(1, 2).contiguous() + new_xyz = ( + pointnet2_utils.gather_operation( + xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) + ) + .transpose(1, 2) + .contiguous() + if self.npoint is not None + else None + ) + + for i in range(len(self.groupers)): + new_features = self.groupers[i]( + xyz, new_xyz, features + ) # (B, C, npoint, nsample) + + new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) + new_features = F.max_pool2d( + new_features, kernel_size=[1, new_features.size(3)] + ) # (B, mlp[-1], npoint, 1) + new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) + + new_features_list.append(new_features) + + return new_xyz, torch.cat(new_features_list, dim=1) + + +class PointnetSAModuleMSG(_PointnetSAModuleBase): + r"""Pointnet set abstrction layer with multiscale grouping + + Parameters + ---------- + npoint : int + Number of features + radii : list of float32 + list of radii to group with + nsamples : list of int32 + Number of samples in each ball query + mlps : list of list of int32 + Spec of the pointnet before the global max_pool for each scale + bn : bool + Use batchnorm + """ + + def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): + # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None + super(PointnetSAModuleMSG, self).__init__() + + assert len(radii) == len(nsamples) == len(mlps) + + self.npoint = npoint + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + for i in range(len(radii)): + radius = radii[i] + nsample = nsamples[i] + self.groupers.append( + pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) + if npoint is not None + else pointnet2_utils.GroupAll(use_xyz) + ) + mlp_spec = mlps[i] + if use_xyz: + mlp_spec[0] += 3 + + self.mlps.append(build_shared_mlp(mlp_spec, bn)) + + +class PointnetSAModule(PointnetSAModuleMSG): + r"""Pointnet set abstrction layer + + Parameters + ---------- + npoint : int + Number of features + radius : float + Radius of ball + nsample : int + Number of samples in the ball query + mlp : list + Spec of the pointnet before the global max_pool + bn : bool + Use batchnorm + """ + + def __init__( + self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True + ): + # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None + super(PointnetSAModule, self).__init__( + mlps=[mlp], + npoint=npoint, + radii=[radius], + nsamples=[nsample], + bn=bn, + use_xyz=use_xyz, + ) + + +class PointnetFPModule(nn.Module): + r"""Propigates the features of one set to another + + Parameters + ---------- + mlp : list + Pointnet module parameters + bn : bool + Use batchnorm + """ + + def __init__(self, mlp, bn=True): + # type: (PointnetFPModule, List[int], bool) -> None + super(PointnetFPModule, self).__init__() + self.mlp = build_shared_mlp(mlp, bn=bn) + + def forward(self, unknown, known, unknow_feats, known_feats): + # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + Parameters + ---------- + unknown : torch.Tensor + (B, n, 3) tensor of the xyz positions of the unknown features + known : torch.Tensor + (B, m, 3) tensor of the xyz positions of the known features + unknow_feats : torch.Tensor + (B, C1, n) tensor of the features to be propigated to + known_feats : torch.Tensor + (B, C2, m) tensor of features to be propigated + + Returns + ------- + new_features : torch.Tensor + (B, mlp[-1], n) tensor of the features of the unknown features + """ + + if known is not None: + dist, idx = pointnet2_utils.three_nn(unknown, known) + dist_recip = 1.0 / (dist + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + + interpolated_feats = pointnet2_utils.three_interpolate( + known_feats, idx, weight + ) + else: + interpolated_feats = known_feats.expand( + *(known_feats.size()[0:2] + [unknown.size(1)]) + ) + + if unknow_feats is not None: + new_features = torch.cat( + [interpolated_feats, unknow_feats], dim=1 + ) # (B, C2 + C1, n) + else: + new_features = interpolated_feats + + new_features = new_features.unsqueeze(-1) + new_features = self.mlp(new_features) + + return new_features.squeeze(-1) diff --git a/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py b/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py new file mode 100644 index 0000000..150fccc --- /dev/null +++ b/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py @@ -0,0 +1,379 @@ +import torch +import torch.nn as nn +import warnings +from torch.autograd import Function +from typing import * + +try: + import pointnet2_ops._ext as _ext +except ImportError: + from torch.utils.cpp_extension import load + import glob + import os.path as osp + import os + + warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.") + + _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") + _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( + osp.join(_ext_src_root, "src", "*.cu") + ) + _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) + + os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" + _ext = load( + "_ext", + sources=_ext_sources, + extra_include_paths=[osp.join(_ext_src_root, "include")], + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], + with_cuda=True, + ) + + +class FurthestPointSampling(Function): + @staticmethod + def forward(ctx, xyz, npoint): + # type: (Any, torch.Tensor, int) -> torch.Tensor + r""" + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance + + Parameters + ---------- + xyz : torch.Tensor + (B, N, 3) tensor where N > npoint + npoint : int32 + number of features in the sampled set + + Returns + ------- + torch.Tensor + (B, npoint) tensor containing the set + """ + out = _ext.furthest_point_sampling(xyz, npoint) + + ctx.mark_non_differentiable(out) + + return out + + @staticmethod + def backward(ctx, grad_out): + return () + + +furthest_point_sample = FurthestPointSampling.apply + + +class GatherOperation(Function): + @staticmethod + def forward(ctx, features, idx): + # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + features : torch.Tensor + (B, C, N) tensor + + idx : torch.Tensor + (B, npoint) tensor of the features to gather + + Returns + ------- + torch.Tensor + (B, C, npoint) tensor + """ + + ctx.save_for_backward(idx, features) + + return _ext.gather_points(features, idx) + + @staticmethod + def backward(ctx, grad_out): + idx, features = ctx.saved_tensors + N = features.size(2) + + grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) + return grad_features, None + + +gather_operation = GatherOperation.apply + + +class ThreeNN(Function): + @staticmethod + def forward(ctx, unknown, known): + # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + r""" + Find the three nearest neighbors of unknown in known + Parameters + ---------- + unknown : torch.Tensor + (B, n, 3) tensor of known features + known : torch.Tensor + (B, m, 3) tensor of unknown features + + Returns + ------- + dist : torch.Tensor + (B, n, 3) l2 distance to the three nearest neighbors + idx : torch.Tensor + (B, n, 3) index of 3 nearest neighbors + """ + dist2, idx = _ext.three_nn(unknown, known) + dist = torch.sqrt(dist2) + + ctx.mark_non_differentiable(dist, idx) + + return dist, idx + + @staticmethod + def backward(ctx, grad_dist, grad_idx): + return () + + +three_nn = ThreeNN.apply + + +class ThreeInterpolate(Function): + @staticmethod + def forward(ctx, features, idx, weight): + # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor + r""" + Performs weight linear interpolation on 3 features + Parameters + ---------- + features : torch.Tensor + (B, c, m) Features descriptors to be interpolated from + idx : torch.Tensor + (B, n, 3) three nearest neighbors of the target features in features + weight : torch.Tensor + (B, n, 3) weights + + Returns + ------- + torch.Tensor + (B, c, n) tensor of the interpolated features + """ + ctx.save_for_backward(idx, weight, features) + + return _ext.three_interpolate(features, idx, weight) + + @staticmethod + def backward(ctx, grad_out): + # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + r""" + Parameters + ---------- + grad_out : torch.Tensor + (B, c, n) tensor with gradients of ouputs + + Returns + ------- + grad_features : torch.Tensor + (B, c, m) tensor with gradients of features + + None + + None + """ + idx, weight, features = ctx.saved_tensors + m = features.size(2) + + grad_features = _ext.three_interpolate_grad( + grad_out.contiguous(), idx, weight, m + ) + + return grad_features, torch.zeros_like(idx), torch.zeros_like(weight) + + +three_interpolate = ThreeInterpolate.apply + + +class GroupingOperation(Function): + @staticmethod + def forward(ctx, features, idx): + # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + features : torch.Tensor + (B, C, N) tensor of features to group + idx : torch.Tensor + (B, npoint, nsample) tensor containing the indicies of features to group with + + Returns + ------- + torch.Tensor + (B, C, npoint, nsample) tensor + """ + ctx.save_for_backward(idx, features) + + return _ext.group_points(features, idx) + + @staticmethod + def backward(ctx, grad_out): + # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] + r""" + + Parameters + ---------- + grad_out : torch.Tensor + (B, C, npoint, nsample) tensor of the gradients of the output from forward + + Returns + ------- + torch.Tensor + (B, C, N) gradient of the features + None + """ + idx, features = ctx.saved_tensors + N = features.size(2) + + grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) + + return grad_features, torch.zeros_like(idx) + + +grouping_operation = GroupingOperation.apply + + +class BallQuery(Function): + @staticmethod + def forward(ctx, radius, nsample, xyz, new_xyz): + # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor + r""" + + Parameters + ---------- + radius : float + radius of the balls + nsample : int + maximum number of features in the balls + xyz : torch.Tensor + (B, N, 3) xyz coordinates of the features + new_xyz : torch.Tensor + (B, npoint, 3) centers of the ball query + + Returns + ------- + torch.Tensor + (B, npoint, nsample) tensor with the indicies of the features that form the query balls + """ + output = _ext.ball_query(new_xyz, xyz, radius, nsample) + + ctx.mark_non_differentiable(output) + + return output + + @staticmethod + def backward(ctx, grad_out): + return () + + +ball_query = BallQuery.apply + + +class QueryAndGroup(nn.Module): + r""" + Groups with a ball query of radius + + Parameters + --------- + radius : float32 + Radius of ball + nsample : int32 + Maximum number of features to gather in the ball + """ + + def __init__(self, radius, nsample, use_xyz=True): + # type: (QueryAndGroup, float, int, bool) -> None + super(QueryAndGroup, self).__init__() + self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz + + def forward(self, xyz, new_xyz, features=None): + # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] + r""" + Parameters + ---------- + xyz : torch.Tensor + xyz coordinates of the features (B, N, 3) + new_xyz : torch.Tensor + centriods (B, npoint, 3) + features : torch.Tensor + Descriptors of the features (B, C, N) + + Returns + ------- + new_features : torch.Tensor + (B, 3 + C, npoint, nsample) tensor + """ + + idx = ball_query(self.radius, self.nsample, xyz, new_xyz) + xyz_trans = xyz.transpose(1, 2).contiguous() + grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) + grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) + + if features is not None: + grouped_features = grouping_operation(features, idx) + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], dim=1 + ) # (B, C + 3, npoint, nsample) + else: + new_features = grouped_features + else: + assert ( + self.use_xyz + ), "Cannot have not features and not use xyz as a feature!" + new_features = grouped_xyz + + return new_features + + +class GroupAll(nn.Module): + r""" + Groups all features + + Parameters + --------- + """ + + def __init__(self, use_xyz=True): + # type: (GroupAll, bool) -> None + super(GroupAll, self).__init__() + self.use_xyz = use_xyz + + def forward(self, xyz, new_xyz, features=None): + # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] + r""" + Parameters + ---------- + xyz : torch.Tensor + xyz coordinates of the features (B, N, 3) + new_xyz : torch.Tensor + Ignored + features : torch.Tensor + Descriptors of the features (B, C, N) + + Returns + ------- + new_features : torch.Tensor + (B, C + 3, 1, N) tensor + """ + + grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) + if features is not None: + grouped_features = features.unsqueeze(2) + if self.use_xyz: + new_features = torch.cat( + [grouped_xyz, grouped_features], dim=1 + ) # (B, 3 + C, 1, N) + else: + new_features = grouped_features + else: + new_features = grouped_xyz + + return new_features diff --git a/pointnet2_ops_lib/setup.py b/pointnet2_ops_lib/setup.py new file mode 100644 index 0000000..faf7154 --- /dev/null +++ b/pointnet2_ops_lib/setup.py @@ -0,0 +1,39 @@ +import glob +import os +import os.path as osp + +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +this_dir = osp.dirname(osp.abspath(__file__)) +_ext_src_root = osp.join("pointnet2_ops", "_ext-src") +_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( + osp.join(_ext_src_root, "src", "*.cu") +) +_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) + +requirements = ["torch>=1.4"] + +exec(open(osp.join("pointnet2_ops", "_version.py")).read()) + +os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" +setup( + name="pointnet2_ops", + version=__version__, + author="Erik Wijmans", + packages=find_packages(), + install_requires=requirements, + ext_modules=[ + CUDAExtension( + name="pointnet2_ops._ext", + sources=_ext_sources, + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": ["-O3", "-Xfatbin", "-compress-all"], + }, + include_dirs=[osp.join(this_dir, _ext_src_root, "include")], + ) + ], + cmdclass={"build_ext": BuildExtension}, + include_package_data=True, +)