@@ -212,7 +212,7 @@ __global__ void NearestNeighborKernelD3(
  }
}

-at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
+at::Tensor NearestNeighborIdxCuda(const at::Tensor& p1, const at::Tensor& p2) {
  const auto N = p1.size(0);
  const auto P1 = p1.size(1);
  const auto P2 = p2.size(1);
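Note on the signature changes throughout this PR: at::Tensor is a reference-counted handle, so passing it by value copies the handle and touches the refcount on every call; const at::Tensor& is the conventional signature for arguments a function does not reassign. Illustrative only (hypothetical names, not part of the diff):

void ByValue(at::Tensor t);       // copies the handle, bumps the refcount
void ByRef(const at::Tensor& t);  // no handle copy; preferred for read-only args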
@@ -20,13 +20,15 @@
//

// CPU implementation.
-at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2);
+at::Tensor NearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2);

// Cuda implementation.
-at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2);
+#ifdef WITH_CUDA
+at::Tensor NearestNeighborIdxCuda(const at::Tensor& p1, const at::Tensor& p2);
+#endif

// Implementation which is exposed.
-at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {
+at::Tensor NearestNeighborIdx(const at::Tensor& p1, const at::Tensor& p2) {
  if (p1.type().is_cuda() && p2.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(p1);
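The exposed NearestNeighborIdx dispatches on device: when both inputs are CUDA tensors it takes the CUDA path guarded by WITH_CUDA, and the CPU path follows in the collapsed portion of the hunk. A minimal caller sketch, assuming the (N, P, D) layout used by the kernels (not part of the diff):

at::Tensor p1 = at::rand({8, 128, 3});  // batch of 8 clouds, 128 points each, D = 3
at::Tensor p2 = at::rand({8, 256, 3});
at::Tensor idx = NearestNeighborIdx(p1, p2);  // (8, 128) int64 indices into p2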
@@ -0,0 +1,62 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include <ATen/ATen.h>

template <typename scalar_t>
void NearestNeighborIdxCpuKernel(
    const at::Tensor& p1,
    const at::Tensor& p2,
    at::Tensor& out,
    const int64_t N,
    const int64_t D,
    const int64_t P1,
    const int64_t P2) {
  auto p1_a = p1.accessor<scalar_t, 3>();
  auto p2_a = p2.accessor<scalar_t, 3>();
  auto out_a = out.accessor<int64_t, 2>();

  // Brute-force search: for each point in p1, scan all of p2 and keep the
  // index of the smallest squared Euclidean distance.
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t i1 = 0; i1 < P1; ++i1) {
      scalar_t min_dist = -1;
      int64_t min_idx = -1;
      for (int64_t i2 = 0; i2 < P2; ++i2) {
        scalar_t dist = 0;
        for (int64_t d = 0; d < D; ++d) {
          scalar_t diff = p1_a[n][i1][d] - p2_a[n][i2][d];
          dist += diff * diff;
        }
        if (min_dist == -1 || dist < min_dist) {
          min_dist = dist;
          min_idx = i2;
        }
      }
      out_a[n][i1] = min_idx;
    }
  }
}

at::Tensor NearestNeighborIdxCpu(const at::Tensor& p1, const at::Tensor& p2) {
  // Tensor::size() returns int64_t; keeping these as int64_t avoids
  // signed/unsigned comparisons in the loops above and a narrowing
  // conversion in the torch::empty() shape below.
  const int64_t N = p1.size(0);
  const int64_t P1 = p1.size(1);
  const int64_t D = p1.size(2);
  const int64_t P2 = p2.size(1);

  auto long_opts = p1.options().dtype(torch::kInt64);
  torch::Tensor out = torch::empty({N, P1}, long_opts);

  AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_idx_cpu", [&] {
    NearestNeighborIdxCpuKernel<scalar_t>(p1, p2, out, N, D, P1, P2);
  });

  return out;
}
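The kernel above is an O(N * P1 * P2 * D) brute-force search. A standalone sketch of the same loop structure on flat row-major arrays, which could serve as a reference check in tests (hypothetical helper, not part of the PR):

#include <cstdint>
#include <vector>

// Hypothetical reference mirroring NearestNeighborIdxCpuKernel: points1 holds
// N*P1*D floats, points2 holds N*P2*D floats; returns N*P1 indices into the
// second cloud.
std::vector<int64_t> NearestNeighborIdxReference(
    const std::vector<float>& points1,
    const std::vector<float>& points2,
    int64_t N, int64_t P1, int64_t P2, int64_t D) {
  std::vector<int64_t> idx(N * P1, -1);
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t i1 = 0; i1 < P1; ++i1) {
      float min_dist = -1;
      for (int64_t i2 = 0; i2 < P2; ++i2) {
        float dist = 0;
        for (int64_t d = 0; d < D; ++d) {
          const float diff = points1[(n * P1 + i1) * D + d] -
                             points2[(n * P2 + i2) * D + d];
          dist += diff * diff;
        }
        if (min_dist == -1 || dist < min_dist) {
          min_dist = dist;
          idx[n * P1 + i1] = i2;
        }
      }
    }
  }
  return idx;
}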

This file was deleted.