Skip to content

Commit eed68f4

Browse files
jcjohnsonfacebook-github-bot
authored andcommitted
Refactor mesh coarse rasterization
Summary: Renaming parts of the mesh coarse rasterization and separating the bounding box calculation. All in preparation for sharing code with point rasterization. Reviewed By: bottler Differential Revision: D30369112 fbshipit-source-id: 3508c0b1239b355030cfa4038d5f3d6a945ebbf4
1 parent 62dbf37 commit eed68f4

File tree

1 file changed

+151
-114
lines changed

1 file changed

+151
-114
lines changed

pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu

Lines changed: 151 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,55 @@
1717
#include "utils/float_math.cuh"
1818
#include "utils/geometry_utils.cuh" // For kEpsilon -- gross
1919

20-
// Get the xyz coordinates of the three vertices for the face given by the
21-
// index face_idx into face_verts.
22-
__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
23-
const float* face_verts,
24-
int face_idx) {
25-
const float x0 = face_verts[face_idx * 9 + 0];
26-
const float y0 = face_verts[face_idx * 9 + 1];
27-
const float z0 = face_verts[face_idx * 9 + 2];
28-
const float x1 = face_verts[face_idx * 9 + 3];
29-
const float y1 = face_verts[face_idx * 9 + 4];
30-
const float z1 = face_verts[face_idx * 9 + 5];
31-
const float x2 = face_verts[face_idx * 9 + 6];
32-
const float y2 = face_verts[face_idx * 9 + 7];
33-
const float z2 = face_verts[face_idx * 9 + 8];
34-
35-
const float3 v0xyz = make_float3(x0, y0, z0);
36-
const float3 v1xyz = make_float3(x1, y1, z1);
37-
const float3 v2xyz = make_float3(x2, y2, z2);
38-
39-
return thrust::make_tuple(v0xyz, v1xyz, v2xyz);
20+
__global__ void TriangleBoundingBoxKernel(
21+
const float* face_verts, // (F, 3, 3)
22+
const int F,
23+
const float blur_radius,
24+
float* bboxes, // (4, F)
25+
bool* skip_face) { // (F,)
26+
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
27+
const int num_threads = blockDim.x * gridDim.x;
28+
const float sqrt_radius = sqrt(blur_radius);
29+
for (int f = tid; f < F; f += num_threads) {
30+
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
31+
const float v0y = face_verts[f * 9 + 0 * 3 + 1];
32+
const float v0z = face_verts[f * 9 + 0 * 3 + 2];
33+
const float v1x = face_verts[f * 9 + 1 * 3 + 0];
34+
const float v1y = face_verts[f * 9 + 1 * 3 + 1];
35+
const float v1z = face_verts[f * 9 + 1 * 3 + 2];
36+
const float v2x = face_verts[f * 9 + 2 * 3 + 0];
37+
const float v2y = face_verts[f * 9 + 2 * 3 + 1];
38+
const float v2z = face_verts[f * 9 + 2 * 3 + 2];
39+
const float xmin = FloatMin3(v0x, v1x, v2x) - sqrt_radius;
40+
const float xmax = FloatMax3(v0x, v1x, v2x) + sqrt_radius;
41+
const float ymin = FloatMin3(v0y, v1y, v2y) - sqrt_radius;
42+
const float ymax = FloatMax3(v0y, v1y, v2y) + sqrt_radius;
43+
const float zmin = FloatMin3(v0z, v1z, v2z);
44+
const bool skip = zmin < kEpsilon;
45+
bboxes[0 * F + f] = xmin;
46+
bboxes[1 * F + f] = xmax;
47+
bboxes[2 * F + f] = ymin;
48+
bboxes[3 * F + f] = ymax;
49+
skip_face[f] = skip;
50+
}
4051
}
4152

42-
__global__ void RasterizeMeshesCoarseCudaKernel(
43-
const float* face_verts,
44-
const int64_t* mesh_to_face_first_idx,
45-
const int64_t* num_faces_per_mesh,
46-
const float blur_radius,
53+
__global__ void RasterizeCoarseCudaKernel(
54+
const float* bboxes, // (4, E) (xmin, xmax, ymin, ymax)
55+
const bool* should_skip, // (E,)
56+
const int64_t* elem_first_idxs,
57+
const int64_t* elems_per_batch,
4758
const int N,
48-
const int F,
59+
const int E,
4960
const int H,
5061
const int W,
5162
const int bin_size,
5263
const int chunk_size,
53-
const int max_faces_per_bin,
54-
int* faces_per_bin,
55-
int* bin_faces) {
64+
const int max_elem_per_bin,
65+
int* elems_per_bin,
66+
int* bin_elems) {
5667
extern __shared__ char sbuf[];
57-
const int M = max_faces_per_bin;
68+
const int M = max_elem_per_bin;
5869
// Integer divide round up
5970
const int num_bins_x = 1 + (W - 1) / bin_size;
6071
const int num_bins_y = 1 + (H - 1) / bin_size;
@@ -71,53 +82,39 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
7182
const float half_pix_y = NDC_y_half_range / H;
7283

7384
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
74-
// stored in shared memory that will track whether each point in the chunk
85+
// stored in shared memory that will track whether each elem in the chunk
7586
// falls into each bin of the image.
7687
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
7788

78-
// Have each block handle a chunk of faces
79-
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
89+
// Have each block handle a chunk of elements
90+
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
8091
const int num_chunks = N * chunks_per_batch;
8192

8293
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
8394
const int batch_idx = chunk / chunks_per_batch; // batch index
8495
const int chunk_idx = chunk % chunks_per_batch;
85-
const int face_start_idx = chunk_idx * chunk_size;
96+
const int elem_chunk_start_idx = chunk_idx * chunk_size;
8697

8798
binmask.block_clear();
88-
const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
89-
const int64_t mesh_face_stop_idx =
90-
mesh_face_start_idx + num_faces_per_mesh[batch_idx];
99+
const int64_t elem_start_idx = elem_first_idxs[batch_idx];
100+
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];
91101

92102
// Have each thread handle a different face within the chunk
93-
for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
94-
const int f_idx = face_start_idx + f;
103+
for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
104+
const int e_idx = elem_chunk_start_idx + e;
95105

96-
// Check if face index corresponds to the mesh in the batch given by
97-
// batch_idx
98-
if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
106+
// Check that we are still within the same element of the batch
107+
if (e_idx >= elem_stop_idx || e_idx < elem_start_idx) {
99108
continue;
100109
}
101110

102-
// Get xyz coordinates of the three face vertices.
103-
const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
104-
const float3 v0 = thrust::get<0>(v012);
105-
const float3 v1 = thrust::get<1>(v012);
106-
const float3 v2 = thrust::get<2>(v012);
107-
108-
// Compute screen-space bbox for the triangle expanded by blur.
109-
float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius);
110-
float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
111-
float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
112-
float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
113-
float zmin = FloatMin3(v0.z, v1.z, v2.z);
114-
115-
// Faces with at least one vertex behind the camera won't render
116-
// correctly and should be removed or clipped before calling the
117-
// rasterizer
118-
if (zmin < kEpsilon) {
111+
if (should_skip[e_idx]) {
119112
continue;
120113
}
114+
const float xmin = bboxes[0 * E + e_idx];
115+
const float xmax = bboxes[1 * E + e_idx];
116+
const float ymin = bboxes[2 * E + e_idx];
117+
const float ymax = bboxes[3 * E + e_idx];
121118

122119
// Brute-force search over all bins; TODO(T54294966) something smarter.
123120
for (int by = 0; by < num_bins_y; ++by) {
@@ -141,39 +138,39 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
141138

142139
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
143140
if (y_overlap && x_overlap) {
144-
binmask.set(by, bx, f);
141+
binmask.set(by, bx, e);
145142
}
146143
}
147144
}
148145
}
149146
__syncthreads();
150-
// Now we have processed every face in the current chunk. We need to
151-
// count the number of faces in each bin so we can write the indices
147+
// Now we have processed every elem in the current chunk. We need to
148+
// count the number of elems in each bin so we can write the indices
152149
// out to global memory. We have each thread handle a different bin.
153150
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
154151
byx += blockDim.x) {
155152
const int by = byx / num_bins_x;
156153
const int bx = byx % num_bins_x;
157154
const int count = binmask.count(by, bx);
158-
const int faces_per_bin_idx =
155+
const int elems_per_bin_idx =
159156
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
160157

161-
// This atomically increments the (global) number of faces found
158+
// This atomically increments the (global) number of elems found
162159
// in the current bin, and gets the previous value of the counter;
163160
// this effectively allocates space in the bin_faces array for the
164-
// faces in the current chunk that fall into this bin.
165-
const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);
161+
// elems in the current chunk that fall into this bin.
162+
const int start = atomicAdd(elems_per_bin + elems_per_bin_idx, count);
166163

167164
// Now loop over the binmask and write the active bits for this bin
168165
// out to bin_faces.
169166
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
170167
by * num_bins_x * M + bx * M + start;
171-
for (int f = 0; f < chunk_size; ++f) {
172-
if (binmask.get(by, bx, f)) {
168+
for (int e = 0; e < chunk_size; ++e) {
169+
if (binmask.get(by, bx, e)) {
173170
// TODO(T54296346) find the correct method for handling errors in
174171
// CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin.
175172
// Either decrease bin size or increase max_faces_per_bin
176-
bin_faces[next_idx] = face_start_idx + f;
173+
bin_elems[next_idx] = elem_chunk_start_idx + e;
177174
next_idx++;
178175
}
179176
}
@@ -182,6 +179,69 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
182179
}
183180
}
184181

182+
at::Tensor RasterizeCoarseCuda(
183+
const at::Tensor& bboxes,
184+
const at::Tensor& should_skip,
185+
const at::Tensor& elem_first_idxs,
186+
const at::Tensor& elems_per_batch,
187+
const std::tuple<int, int> image_size,
188+
const int bin_size,
189+
const int max_elems_per_bin) {
190+
// Set the device for the kernel launch based on the device of the input
191+
at::cuda::CUDAGuard device_guard(bboxes.device());
192+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
193+
194+
const int H = std::get<0>(image_size);
195+
const int W = std::get<1>(image_size);
196+
197+
const int E = bboxes.size(1);
198+
const int N = elems_per_batch.size(0);
199+
const int M = max_elems_per_bin;
200+
201+
// Integer divide round up
202+
const int num_bins_y = 1 + (H - 1) / bin_size;
203+
const int num_bins_x = 1 + (W - 1) / bin_size;
204+
205+
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
206+
std::stringstream ss;
207+
ss << "In RasterizeCoarseCuda got num_bins_y: " << num_bins_y
208+
<< ", num_bins_x: " << num_bins_x << ", "
209+
<< "; that's too many!";
210+
AT_ERROR(ss.str());
211+
}
212+
auto opts = elems_per_batch.options().dtype(at::kInt);
213+
at::Tensor elems_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
214+
at::Tensor bin_elems = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
215+
216+
if (bin_elems.numel() == 0) {
217+
AT_CUDA_CHECK(cudaGetLastError());
218+
return bin_elems;
219+
}
220+
221+
const int chunk_size = 512;
222+
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
223+
const size_t blocks = 64;
224+
const size_t threads = 512;
225+
226+
RasterizeCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
227+
bboxes.contiguous().data_ptr<float>(),
228+
should_skip.contiguous().data_ptr<bool>(),
229+
elem_first_idxs.contiguous().data_ptr<int64_t>(),
230+
elems_per_batch.contiguous().data_ptr<int64_t>(),
231+
N,
232+
E,
233+
H,
234+
W,
235+
bin_size,
236+
chunk_size,
237+
M,
238+
elems_per_bin.data_ptr<int32_t>(),
239+
bin_elems.data_ptr<int32_t>());
240+
241+
AT_CUDA_CHECK(cudaGetLastError());
242+
return bin_elems;
243+
}
244+
185245
__global__ void RasterizePointsCoarseCudaKernel(
186246
const float* points, // (P, 3)
187247
const int64_t* cloud_to_packed_first_idx, // (N)
@@ -352,55 +412,32 @@ at::Tensor RasterizeMeshesCoarseCuda(
352412
at::cuda::CUDAGuard device_guard(face_verts.device());
353413
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
354414

355-
const int H = std::get<0>(image_size);
356-
const int W = std::get<1>(image_size);
357-
415+
// Allocate tensors for bboxes and should_skip
358416
const int F = face_verts.size(0);
359-
const int N = num_faces_per_mesh.size(0);
360-
const int M = max_faces_per_bin;
361-
362-
// Integer divide round up.
363-
const int num_bins_y = 1 + (H - 1) / bin_size;
364-
const int num_bins_x = 1 + (W - 1) / bin_size;
365-
366-
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
367-
std::stringstream ss;
368-
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
369-
<< ", num_bins_x: " << num_bins_x << ", "
370-
<< "; that's too many!";
371-
AT_ERROR(ss.str());
372-
}
373-
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
374-
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
375-
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
376-
377-
if (bin_faces.numel() == 0) {
378-
AT_CUDA_CHECK(cudaGetLastError());
379-
return bin_faces;
380-
}
381-
382-
const int chunk_size = 512;
383-
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
384-
const size_t blocks = 64;
385-
const size_t threads = 512;
386-
387-
RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
417+
auto float_opts = face_verts.options().dtype(at::kFloat);
418+
auto bool_opts = face_verts.options().dtype(at::kBool);
419+
at::Tensor bboxes = at::empty({4, F}, float_opts);
420+
at::Tensor should_skip = at::empty({F}, bool_opts);
421+
422+
// Launch kernel to compute triangle bboxes
423+
const size_t blocks = 128;
424+
const size_t threads = 256;
425+
TriangleBoundingBoxKernel<<<blocks, threads, 0, stream>>>(
388426
face_verts.contiguous().data_ptr<float>(),
389-
mesh_to_face_first_idx.contiguous().data_ptr<int64_t>(),
390-
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
391-
blur_radius,
392-
N,
393427
F,
394-
H,
395-
W,
396-
bin_size,
397-
chunk_size,
398-
M,
399-
faces_per_bin.data_ptr<int32_t>(),
400-
bin_faces.data_ptr<int32_t>());
401-
428+
blur_radius,
429+
bboxes.contiguous().data_ptr<float>(),
430+
should_skip.contiguous().data_ptr<bool>());
402431
AT_CUDA_CHECK(cudaGetLastError());
403-
return bin_faces;
432+
433+
return RasterizeCoarseCuda(
434+
bboxes,
435+
should_skip,
436+
mesh_to_face_first_idx,
437+
num_faces_per_mesh,
438+
image_size,
439+
bin_size,
440+
max_faces_per_bin);
404441
}
405442

406443
at::Tensor RasterizePointsCoarseCuda(

0 commit comments

Comments
 (0)