
Commit cb170ac

bottler authored and facebook-github-bot committed
Avoid torch/extension.h in cuda
Summary:
Unlike other .cu files, sigmoid_alpha_blend.cu uses torch/extension.h. Avoid it for a possible build speed win and because of a reported problem (#843) on Windows with CUDA 11.4.

Reviewed By: nikhilaravi

Differential Revision: D31054121

fbshipit-source-id: 53a1f985a1695a044dfd2ee1a5b0adabdf280595
1 parent fe5bfa5 commit cb170ac

File tree

1 file changed: +28 −28 lines changed


pytorch3d/csrc/blending/sigmoid_alpha_blend.cu

+28 −28

@@ -6,18 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */

+#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <cmath>
 #include <vector>

 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendForwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -67,7 +67,7 @@ __global__ void SigmoidAlphaBlendForwardKernel(
   }
 }

-torch::Tensor SigmoidAlphaBlendForwardCuda(
+at::Tensor SigmoidAlphaBlendForwardCuda(
     const at::Tensor& distances, // (N, H, W, K)
     const at::Tensor& pix_to_face, // (N, H, W, K)
     const float sigma) {
@@ -99,9 +99,9 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
       distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
         // clang-format off
         SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-          distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-          pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-          alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+          distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+          pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+          alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
           sigma,
           N,
           H,
@@ -117,11 +117,11 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendBackwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> grad_distances, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> grad_alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> grad_distances, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -162,7 +162,7 @@ __global__ void SigmoidAlphaBlendBackwardKernel(
   }
 }

-torch::Tensor SigmoidAlphaBlendBackwardCuda(
+at::Tensor SigmoidAlphaBlendBackwardCuda(
     const at::Tensor& grad_alphas, // (N, H, W)
     const at::Tensor& alphas, // (N, H, W)
     const at::Tensor& distances, // (N, H, W, K)
@@ -195,20 +195,20 @@ torch::Tensor SigmoidAlphaBlendBackwardCuda(

   AT_DISPATCH_FLOATING_TYPES(
       distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
-        SigmoidAlphaBlendBackwardKernel<scalar_t>
-            <<<blocks, threads, 0, stream>>>(
-            // clang-format off
-            grad_alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-            grad_distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            // clang-format on
-            sigma,
-            N,
-            H,
-            W,
-            K);
+        SigmoidAlphaBlendBackwardKernel<
+            scalar_t><<<blocks, threads, 0, stream>>>(
+            // clang-format off
+            grad_alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+            grad_distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            // clang-format on
+            sigma,
+            N,
+            H,
+            W,
+            K);
       }));

   AT_CUDA_CHECK(cudaGetLastError());
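For context, the pattern applied above is that a standalone .cu translation unit only needs the ATen headers, with tensor and accessor types spelled via the at:: namespace, so torch/extension.h (which pulls in much more of the C++ frontend and Python-binding headers than a kernel file uses) can be dropped. Below is a minimal sketch of that pattern under the same includes as the patched file; ScaleKernel and ScaleCuda are hypothetical names for illustration only, not PyTorch3D functions.

// Sketch: a self-contained .cu launcher using only ATen headers,
// mirroring the includes this commit leaves in place.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

template <typename scalar_t>
__global__ void ScaleKernel(
    const at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> input,
    at::PackedTensorAccessor64<scalar_t, 1, at::RestrictPtrTraits> output,
    const scalar_t factor,
    const int64_t N) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    output[i] = input[i] * factor;
  }
}

at::Tensor ScaleCuda(const at::Tensor& input, const float factor) {
  // at:: types throughout; nothing here needs torch/extension.h.
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  at::Tensor output = at::empty_like(input);
  const int64_t N = input.numel();
  if (N == 0) {
    return output;
  }
  const int threads = 256;
  const int blocks = static_cast<int>((N + threads - 1) / threads);

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scale_kernel", ([&] {
        ScaleKernel<scalar_t><<<blocks, threads, 0, stream>>>(
            input.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(),
            output.packed_accessor64<scalar_t, 1, at::RestrictPtrTraits>(),
            static_cast<scalar_t>(factor),
            N);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}

With this split, nvcc compiles the kernel file against ATen headers only, while any Python binding code that genuinely needs torch/extension.h can live in a separate host-side .cpp translation unit; that separation is where the possible build-speed gain and the Windows/CUDA 11.4 workaround come from.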
