@@ -6,18 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
 #include <cmath>
 #include <vector>
 
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendForwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -67,7 +67,7 @@ __global__ void SigmoidAlphaBlendForwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendForwardCuda(
+at::Tensor SigmoidAlphaBlendForwardCuda(
     const at::Tensor& distances, // (N, H, W, K)
     const at::Tensor& pix_to_face, // (N, H, W, K)
     const float sigma) {
@@ -99,9 +99,9 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
       distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
         // clang-format off
         SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-          distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-          pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-          alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+          distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+          pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+          alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
           sigma,
           N,
           H,
@@ -117,11 +117,11 @@ torch::Tensor SigmoidAlphaBlendForwardCuda(
 template <typename scalar_t>
 __global__ void SigmoidAlphaBlendBackwardKernel(
     // clang-format off
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> alphas, // (N, H, W)
-    const torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> distances, // (N, H, W, K)
-    const torch::PackedTensorAccessor64<int64_t, 4, torch::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
-    torch::PackedTensorAccessor64<scalar_t, 4, torch::RestrictPtrTraits> grad_distances, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> grad_alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> alphas, // (N, H, W)
+    const at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> distances, // (N, H, W, K)
+    const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> pix_to_face, // (N, H, W, K)
+    at::PackedTensorAccessor64<scalar_t, 4, at::RestrictPtrTraits> grad_distances, // (N, H, W, K)
     // clang-format on
     const scalar_t sigma,
     const int N,
@@ -162,7 +162,7 @@ __global__ void SigmoidAlphaBlendBackwardKernel(
   }
 }
 
-torch::Tensor SigmoidAlphaBlendBackwardCuda(
+at::Tensor SigmoidAlphaBlendBackwardCuda(
    const at::Tensor& grad_alphas, // (N, H, W)
    const at::Tensor& alphas, // (N, H, W)
    const at::Tensor& distances, // (N, H, W, K)
@@ -195,20 +195,20 @@ torch::Tensor SigmoidAlphaBlendBackwardCuda(
 
   AT_DISPATCH_FLOATING_TYPES(
       distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
-        SigmoidAlphaBlendBackwardKernel<scalar_t>
-            <<<blocks, threads, 0, stream>>>(
-            // clang-format off
-            grad_alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            alphas.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
-            distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            pix_to_face.packed_accessor64<int64_t, 4, torch::RestrictPtrTraits>(),
-            grad_distances.packed_accessor64<scalar_t, 4, torch::RestrictPtrTraits>(),
-            // clang-format on
-            sigma,
-            N,
-            H,
-            W,
-            K);
+        SigmoidAlphaBlendBackwardKernel<
+            scalar_t><<<blocks, threads, 0, stream>>>(
+            // clang-format off
+            grad_alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            alphas.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            pix_to_face.packed_accessor64<int64_t, 4, at::RestrictPtrTraits>(),
+            grad_distances.packed_accessor64<scalar_t, 4, at::RestrictPtrTraits>(),
+            // clang-format on
+            sigma,
+            N,
+            H,
+            W,
+            K);
       }));
 
   AT_CUDA_CHECK(cudaGetLastError());
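The change is purely mechanical: every `torch::` qualifier becomes `at::`, and `<torch/extension.h>` (which pulls in the full Python-extension headers) is swapped for the lighter `<ATen/ATen.h>`, so this `.cu` file depends on ATen alone. Below is a minimal, self-contained sketch of the same ATen-only launch pattern; `ScaleKernel` and `ScaleCuda` are hypothetical names for illustration, not part of this commit.

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical example kernel (not part of this commit): scales a 2D tensor
// elementwise using only at:: types, so <torch/extension.h> is not needed.
template <typename scalar_t>
__global__ void ScaleKernel(
    const at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> in,
    at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> out,
    const scalar_t scale) {
  const int64_t numel = in.size(0) * in.size(1);
  // Grid-stride loop over all elements.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += gridDim.x * blockDim.x) {
    const int64_t r = i / in.size(1);
    const int64_t c = i % in.size(1);
    out[r][c] = scale * in[r][c];
  }
}

at::Tensor ScaleCuda(const at::Tensor& in, const float scale) {
  // Make sure the launch happens on the input's device.
  at::cuda::CUDAGuard device_guard(in.device());
  at::Tensor out = at::empty_like(in);
  const dim3 blocks(64);
  const dim3 threads(128);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(in.scalar_type(), "scale_kernel", ([&] {
    ScaleKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        in.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(),
        out.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(),
        static_cast<scalar_t>(scale));
  }));
  AT_CUDA_CHECK(cudaGetLastError());
  return out;
}
```

Everything used here (`at::PackedTensorAccessor64`, `at::RestrictPtrTraits`, `AT_DISPATCH_FLOATING_TYPES`, `at::cuda::getCurrentCUDAStream`, `AT_CUDA_CHECK`) comes from the same three headers the patched file includes, which is what makes dropping `<torch/extension.h>` possible.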