/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "utils/pytorch3d_cutils.h"

// A chunk of work is blocksize-many points of P1.
// The number of potential chunks to do is N*(1+(P1-1)/blocksize);
// call (1+(P1-1)/blocksize) chunks_per_cloud.
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize, etc.
// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
// blocksize*(i%chunks_per_cloud).
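//
// A worked example with illustrative numbers (not from the source): for
// N = 2 clouds with P1 = 1000 points and blocksize = 256 threads,
// chunks_per_cloud = 1 + (1000 - 1) / 256 = 4 (integer division), giving
// N * chunks_per_cloud = 8 chunks in total. With gridSize = 3, block 0
// handles chunks 0, 3, 6; block 1 handles 1, 4, 7; block 2 handles 2, 5.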
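
// BallQueryKernel: for each valid point p1[n][i], scan cloud n of p2 in
// index order and record the index and squared distance of up to K points
// that fall within the squared radius. Output entries that are never
// written keep their initial values (idxs is pre-filled with -1 and dists
// with 0 by the host code below).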
template <typename scalar_t>
__global__ void BallQueryKernel(
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p1,
    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> p2,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths1,
    const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits>
        lengths2,
    at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
    const int64_t K,
    const float radius2) {
  const int64_t N = p1.size(0);
  const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
  const int64_t chunks_to_do = N * chunks_per_cloud;
  const int D = p1.size(2);

  for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
    const int64_t n = chunk / chunks_per_cloud; // batch index
    const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
    int64_t i = start_point + threadIdx.x;

    // Skip this thread if its point index is beyond the valid length of
    // cloud n; clouds in a heterogeneous batch are padded up to P1.
    if (i >= lengths1[n]) {
      continue;
    }

    // Iterate over points in p2 until K neighbors have been found or all
    // points have been considered.
    for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
      // Squared Euclidean distance between p1[n][i] and p2[n][j].
      scalar_t dist2 = 0.0;
      for (int d = 0; d < D; ++d) {
        scalar_t diff = p1[n][i][d] - p2[n][j][d];
        dist2 += (diff * diff);
      }

      if (dist2 < radius2) {
        // The point is within the radius: record its index and squared
        // distance.
        idxs[n][i][count] = j;
        dists[n][i][count] = dist2;

        // Increment the number of neighbors selected for point i.
        ++count;
      }
    }
  }
}
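
// BallQueryCuda: host entry point. Checks that the inputs share a device
// and dtype, allocates the outputs, and launches BallQueryKernel on the
// current CUDA stream. Returns (idxs, dists): idxs is an int64 tensor of
// shape (N, P1, K) holding indices into p2 (padded with -1), and dists
// holds the corresponding squared distances (0 where idxs is -1).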
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
    const at::Tensor& p1, // (N, P1, 3)
    const at::Tensor& p2, // (N, P2, 3)
    const at::Tensor& lengths1, // (N,)
    const at::Tensor& lengths2, // (N,)
    int K,
    float radius) {
  // Check that the inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
      lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
  at::CheckedFrom c = "BallQueryCuda";
  at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
  at::checkAllSameType(c, {p1_t, p2_t});

  // Set the device for the kernel launch based on the device of p1
  at::cuda::CUDAGuard device_guard(p1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(
      p2.size(2) == p1.size(2), "Point sets must have the same last dimension");

  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int64_t K_64 = K;
  // The kernel compares squared distances, which avoids a sqrt per pair.
  const float radius2 = radius * radius;

  // Output tensors: indices of neighbors for each point in p1 (padded with
  // -1) and the corresponding squared distances (0 where none was found).
  auto long_dtype = lengths1.options().dtype(at::kLong);
  auto idxs = at::full({N, P1, K}, -1, long_dtype);
  auto dists = at::zeros({N, P1, K}, p1.options());

  if (idxs.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(idxs, dists);
  }

  // Fixed launch configuration; the grid-stride loop in the kernel covers
  // any number of chunks.
  const size_t blocks = 256;
  const size_t threads = 256;

  AT_DISPATCH_FLOATING_TYPES(
      p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
        // Use scalar_t rather than hard-coded float accessors so that
        // double inputs dispatch to the matching kernel instantiation.
        BallQueryKernel<<<blocks, threads, 0, stream>>>(
            p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
            idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
            dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
            K_64,
            radius2);
      }));

  AT_CUDA_CHECK(cudaGetLastError());

  return std::make_tuple(idxs, dists);
}
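
// A minimal host-side usage sketch (not part of the original source; the
// guard macro BALL_QUERY_EXAMPLE_MAIN is a hypothetical name used only so
// this example does not affect normal builds). It assumes a CUDA device is
// available and queries up to K neighbors within `radius` for two clouds.
#ifdef BALL_QUERY_EXAMPLE_MAIN
#include <tuple>

int main() {
  const int64_t N = 2, P1 = 8, P2 = 16, D = 3;
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  at::Tensor p1 = at::rand({N, P1, D}, opts);
  at::Tensor p2 = at::rand({N, P2, D}, opts);
  // Both clouds are full-length here; a heterogeneous batch would pass
  // smaller per-cloud lengths.
  at::Tensor lengths1 = at::full({N}, P1, opts.dtype(at::kLong));
  at::Tensor lengths2 = at::full({N}, P2, opts.dtype(at::kLong));

  at::Tensor idxs, dists;
  std::tie(idxs, dists) =
      BallQueryCuda(p1, p2, lengths1, lengths2, /*K=*/4, /*radius=*/0.5f);
  // idxs: (2, 8, 4) indices into p2, padded with -1 where fewer than K
  // neighbors fall within the radius; dists: matching squared distances.
  return 0;
}
#endif // BALL_QUERY_EXAMPLE_MAIN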