diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu index e1f119d00..62e6690a0 100644 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.cu @@ -212,7 +212,7 @@ __global__ void NearestNeighborKernelD3( } } -at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) { +at::Tensor NearestNeighborIdxCuda(const at::Tensor& p1, const at::Tensor& p2) { const auto N = p1.size(0); const auto P1 = p1.size(1); const auto P2 = p2.size(1); diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h index 51c7e72e2..8b0fa8676 100644 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h @@ -20,13 +20,15 @@ // // CPU implementation. -at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2); +at::Tensor NearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2); // Cuda implementation. -at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2); +#ifdef WITH_CUDA +at::Tensor NearestNeighborIdxCuda(const at::Tensor& p1, const at::Tensor& p2); +#endif // Implementation which is exposed. -at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) { +at::Tensor NearestNeighborIdx(const at::Tensor& p1, const at::Tensor& p2) { if (p1.type().is_cuda() && p2.type().is_cuda()) { #ifdef WITH_CUDA CHECK_CONTIGUOUS_CUDA(p1); diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points_cpu.cpp b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points_cpu.cpp new file mode 100644 index 000000000..c59f2be8a --- /dev/null +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points_cpu.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include + +template +void NearestNeighborIdxCpuKernel( + const at::Tensor& p1, + const at::Tensor& p2, + at::Tensor& out, + const size_t N, + const size_t D, + const size_t P1, + const size_t P2) { + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto out_a = out.accessor(); + + for (size_t n = 0; n < N; ++n) { + for (size_t i1 = 0; i1 < P1; ++i1) { + scalar_t min_dist = -1; + int64_t min_idx = -1; + for (int64_t i2 = 0; i2 < P2; ++i2) { + scalar_t dist = 0; + for (size_t d = 0; d < D; ++d) { + scalar_t diff = p1_a[n][i1][d] - p2_a[n][i2][d]; + dist += diff * diff; + } + if (min_dist == -1 || dist < min_dist) { + min_dist = dist; + min_idx = i2; + } + } + out_a[n][i1] = min_idx; + } + } +} + +at::Tensor NearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2) { + const size_t N = p1.size(0); + const size_t P1 = p1.size(1); + const size_t D = p1.size(2); + const size_t P2 = p2.size(1); + + auto long_opts = p1.options().dtype(torch::kInt64); + torch::Tensor out = torch::empty({N, P1}, long_opts); + + AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_idx_cpu", [&] { + NearestNeighborIdxCpuKernel( + p1, + p2, + out, + N, + D, + P1, + P2 + ); + }); + + return out; +} diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp deleted file mode 100644 index 3dd373b90..000000000 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbors_points_cpu.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -#include - -at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2) { - const int N = p1.size(0); - const int P1 = p1.size(1); - const int D = p1.size(2); - const int P2 = p2.size(1); - - auto long_opts = p1.options().dtype(torch::kInt64); - torch::Tensor out = torch::empty({N, P1}, long_opts); - - auto p1_a = p1.accessor(); - auto p2_a = p2.accessor(); - auto out_a = out.accessor(); - - for (int n = 0; n < N; ++n) { - for (int i1 = 0; i1 < P1; ++i1) { - // TODO: support other floating-point types? - float min_dist = -1; - int64_t min_idx = -1; - for (int i2 = 0; i2 < P2; ++i2) { - float dist = 0; - for (int d = 0; d < D; ++d) { - float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; - dist += diff * diff; - } - if (min_dist == -1 || dist < min_dist) { - min_dist = dist; - min_idx = i2; - } - } - out_a[n][i1] = min_idx; - } - } - return out; -}