
Commit bbc7573

jcjohnson authored and facebook-github-bot committed
Unify coarse rasterization for points and meshes
Summary: There has historically been a lot of duplication between the coarse rasterization logic for point clouds and meshes. Previously the only difference between the coarse rasterization kernels for points and meshes was the logic for checking whether a point or triangle intersects a tile in the image. This diff factors out everything else they share: we implement a generic coarse rasterization kernel that takes a set of 2D bounding boxes rather than geometric primitives, and separate kernels that compute 2D bounding boxes for points and for triangles.

This change does not affect the Python API at all. It also should not change any rasterization behavior, since it is purely a refactoring of the existing logic.

I see this diff as the first of several pieces of rasterizer refactoring. Follow-up diffs should do the following:

- Add a check for bin overflow in the generic coarse rasterizer kernel: allocate a global scalar flag that kernel worker threads can write to when they detect bin overflow. The C++ launcher function can then check this flag after the kernel returns and issue a warning to the user in case of overflow (a minimal sketch follows this list).
- As a slightly more involved mechanism, if bin overflow is detected then the coarse kernel can continue running in order to count how many elements fall into each bin, without actually writing their indices to the coarse output tensor. The actual number of entries per bin can then be used to re-allocate the output tensor and re-run the coarse rasterization kernel, so that bin overflow is avoided automatically.
- The unification of the coarse rasterization kernels also allows us to insert an extra CUDA kernel prior to coarse rasterization that filters out primitives outside the view frustum (see the second sketch below). This would be helpful for rendering full scenes (e.g. Matterport data) where only a small piece of the mesh is actually visible at any one time.

Reviewed By: bottler

Differential Revision: D25710361

fbshipit-source-id: 9c9dea512cb339c42adb3c92e7733fedd586ce1b
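Below is a minimal sketch of the proposed overflow flag, assuming a single zero-initialized device-side boolean; the kernel and parameter names (BinOverflowCheckKernel, elems_per_bin, overflow) are illustrative and not part of this diff:

// Sketch only: set a global flag if any bin received more elements
// than it has room for. Runs one grid-stride loop over the bins.
__global__ void BinOverflowCheckKernel(
    const int* elems_per_bin, // (num_bins,) elements found per bin
    const int num_bins,
    const int max_elems_per_bin,
    bool* overflow) { // single device-side flag, zero-initialized
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int b = tid; b < num_bins; b += num_threads) {
    if (elems_per_bin[b] > max_elems_per_bin) {
      // A plain store suffices here: every writer stores the same value.
      *overflow = true;
    }
  }
}

The launcher could allocate the flag as a one-element boolean tensor, pass its data pointer to the kernel, and after the kernel returns either warn the user (the simple mechanism) or re-allocate the output tensor using the true per-bin counts and run coarse rasterization again (the more involved mechanism).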
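And a sketch of the frustum-culling prefilter from the last bullet; it reuses the (4, E) bounding-box layout and should_skip flags introduced by this diff, but the kernel name and the NDC half-range arguments are assumptions:

// Sketch only: mark primitives whose NDC bounding box lies entirely
// outside the visible NDC range so the coarse kernel can skip them.
__global__ void FrustumCullKernel(
    const float* bboxes, // (4, E): xmin, xmax, ymin, ymax in NDC
    const int E,
    const float ndc_x_half_range, // half-extent of visible NDC x range
    const float ndc_y_half_range, // half-extent of visible NDC y range
    bool* should_skip) { // (E,) already true for e.g. points behind camera
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int e = tid; e < E; e += num_threads) {
    const bool outside = bboxes[1 * E + e] < -ndc_x_half_range || // xmax
        bboxes[0 * E + e] > ndc_x_half_range || // xmin
        bboxes[3 * E + e] < -ndc_y_half_range || // ymax
        bboxes[2 * E + e] > ndc_y_half_range; // ymin
    should_skip[e] = should_skip[e] || outside;
  }
}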
1 parent eed68f4 commit bbc7573

File tree

1 file changed (+45, −190)

pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu

Lines changed: 45 additions & 190 deletions
@@ -50,6 +50,29 @@ __global__ void TriangleBoundingBoxKernel(
   }
 }
 
+__global__ void PointBoundingBoxKernel(
+    const float* points, // (P, 3)
+    const float* radius, // (P,)
+    const int P,
+    float* bboxes, // (4, P)
+    bool* skip_points) {
+  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int num_threads = blockDim.x * gridDim.x;
+  for (int p = tid; p < P; p += num_threads) {
+    const float x = points[p * 3 + 0];
+    const float y = points[p * 3 + 1];
+    const float z = points[p * 3 + 2];
+    const float r = radius[p];
+    // TODO: change to kEpsilon to match triangles?
+    const bool skip = z < 0;
+    bboxes[0 * P + p] = x - r;
+    bboxes[1 * P + p] = x + r;
+    bboxes[2 * P + p] = y - r;
+    bboxes[3 * P + p] = y + r;
+    skip_points[p] = skip;
+  }
+}
+
 __global__ void RasterizeCoarseCudaKernel(
     const float* bboxes, // (4, E) (xmin, xmax, ymin, ymax)
     const bool* should_skip, // (E,)
@@ -242,150 +265,6 @@ at::Tensor RasterizeCoarseCuda(
   return bin_elems;
 }
 
-__global__ void RasterizePointsCoarseCudaKernel(
-    const float* points, // (P, 3)
-    const int64_t* cloud_to_packed_first_idx, // (N)
-    const int64_t* num_points_per_cloud, // (N)
-    const float* radius,
-    const int N,
-    const int P,
-    const int H,
-    const int W,
-    const int bin_size,
-    const int chunk_size,
-    const int max_points_per_bin,
-    int* points_per_bin,
-    int* bin_points) {
-  extern __shared__ char sbuf[];
-  const int M = max_points_per_bin;
-
-  // Integer divide round up
-  const int num_bins_x = 1 + (W - 1) / bin_size;
-  const int num_bins_y = 1 + (H - 1) / bin_size;
-
-  // NDC range depends on the ratio of W/H
-  // The shorter side from (H, W) is given an NDC range of 2.0 and
-  // the other side is scaled by the ratio of H:W.
-  const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
-  const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
-
-  // Size of half a pixel in NDC units is the NDC half range
-  // divided by the corresponding image dimension
-  const float half_pix_x = NDC_x_half_range / W;
-  const float half_pix_y = NDC_y_half_range / H;
-
-  // This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
-  // stored in shared memory that will track whether each point in the chunk
-  // falls into each bin of the image.
-  BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
-
-  // Have each block handle a chunk of points and build a 3D bitmask in
-  // shared memory to mark which points hit which bins. In this first phase,
-  // each thread processes one point at a time. After processing the chunk,
-  // one thread is assigned per bin, and the thread counts and writes the
-  // points for the bin out to global memory.
-  const int chunks_per_batch = 1 + (P - 1) / chunk_size;
-  const int num_chunks = N * chunks_per_batch;
-  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
-    const int batch_idx = chunk / chunks_per_batch;
-    const int chunk_idx = chunk % chunks_per_batch;
-    const int point_start_idx = chunk_idx * chunk_size;
-
-    binmask.block_clear();
-
-    // Using the batch index of the thread get the start and stop
-    // indices for the points.
-    const int64_t cloud_point_start_idx = cloud_to_packed_first_idx[batch_idx];
-    const int64_t cloud_point_stop_idx =
-        cloud_point_start_idx + num_points_per_cloud[batch_idx];
-
-    // Have each thread handle a different point within the chunk
-    for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
-      const int p_idx = point_start_idx + p;
-
-      // Check if point index corresponds to the cloud in the batch given by
-      // batch_idx.
-      if (p_idx >= cloud_point_stop_idx || p_idx < cloud_point_start_idx) {
-        continue;
-      }
-
-      const float px = points[p_idx * 3 + 0];
-      const float py = points[p_idx * 3 + 1];
-      const float pz = points[p_idx * 3 + 2];
-      const float p_radius = radius[p_idx];
-      if (pz < 0)
-        continue; // Don't render points behind the camera.
-      const float px0 = px - p_radius;
-      const float px1 = px + p_radius;
-      const float py0 = py - p_radius;
-      const float py1 = py + p_radius;
-
-      // Brute-force search over all bins; TODO something smarter?
-      // For example we could compute the exact bin where the point falls,
-      // then check neighboring bins. This way we wouldn't have to check
-      // all bins (however then we might have more warp divergence?)
-      for (int by = 0; by < num_bins_y; ++by) {
-        // Get y extent for the bin. PixToNonSquareNdc gives us the location of
-        // the center of each pixel, so we need to add/subtract a half
-        // pixel to get the true extent of the bin.
-        const float by0 = PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
-        const float by1 =
-            PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
-        const bool y_overlap = (py0 <= by1) && (by0 <= py1);
-
-        if (!y_overlap) {
-          continue;
-        }
-        for (int bx = 0; bx < num_bins_x; ++bx) {
-          // Get x extent for the bin; again we need to adjust the
-          // output of PixToNonSquareNdc by half a pixel.
-          const float bx0 = PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
-          const float bx1 =
-              PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
-          const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
-
-          if (x_overlap) {
-            binmask.set(by, bx, p);
-          }
-        }
-      }
-    }
-    __syncthreads();
-    // Now we have processed every point in the current chunk. We need to
-    // count the number of points in each bin so we can write the indices
-    // out to global memory. We have each thread handle a different bin.
-    for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
-         byx += blockDim.x) {
-      const int by = byx / num_bins_x;
-      const int bx = byx % num_bins_x;
-      const int count = binmask.count(by, bx);
-      const int points_per_bin_idx =
-          batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
-
-      // This atomically increments the (global) number of points found
-      // in the current bin, and gets the previous value of the counter;
-      // this effectively allocates space in the bin_points array for the
-      // points in the current chunk that fall into this bin.
-      const int start = atomicAdd(points_per_bin + points_per_bin_idx, count);
-
-      // Now loop over the binmask and write the active bits for this bin
-      // out to bin_points.
-      int next_idx = batch_idx * num_bins_y * num_bins_x * M +
-          by * num_bins_x * M + bx * M + start;
-      for (int p = 0; p < chunk_size; ++p) {
-        if (binmask.get(by, bx, p)) {
-          // TODO: Throw an error if next_idx >= M -- this means that
-          // we got more than max_points_per_bin in this bin
-          // TODO: check if atomicAdd is needed in line 265.
-          bin_points[next_idx] = point_start_idx + p;
-          next_idx++;
-        }
-      }
-    }
-    __syncthreads();
-  }
-}
-
 at::Tensor RasterizeMeshesCoarseCuda(
     const at::Tensor& face_verts,
     const at::Tensor& mesh_to_face_first_idx,
@@ -442,8 +321,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
 
 at::Tensor RasterizePointsCoarseCuda(
     const at::Tensor& points, // (P, 3)
-    const at::Tensor& cloud_to_packed_first_idx, // (N)
-    const at::Tensor& num_points_per_cloud, // (N)
+    const at::Tensor& cloud_to_packed_first_idx, // (N,)
+    const at::Tensor& num_points_per_cloud, // (N,)
     const std::tuple<int, int> image_size,
    const at::Tensor& radius,
     const int bin_size,
@@ -465,54 +344,30 @@ at::Tensor RasterizePointsCoarseCuda(
   at::cuda::CUDAGuard device_guard(points.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  const int H = std::get<0>(image_size);
-  const int W = std::get<1>(image_size);
-
+  // Allocate tensors for bboxes and should_skip
   const int P = points.size(0);
-  const int N = num_points_per_cloud.size(0);
-  const int M = max_points_per_bin;
-
-  // Integer divide round up.
-  const int num_bins_y = 1 + (H - 1) / bin_size;
-  const int num_bins_x = 1 + (W - 1) / bin_size;
+  auto float_opts = points.options().dtype(at::kFloat);
+  auto bool_opts = points.options().dtype(at::kBool);
+  at::Tensor bboxes = at::empty({4, P}, float_opts);
+  at::Tensor should_skip = at::empty({P}, bool_opts);
 
-  if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
-    // Make sure we do not use too much shared memory.
-    std::stringstream ss;
-    ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
-       << ", num_bins_x: " << num_bins_x << ", "
-       << "; that's too many!";
-    AT_ERROR(ss.str());
-  }
-  auto opts = num_points_per_cloud.options().dtype(at::kInt);
-  at::Tensor points_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
-  at::Tensor bin_points = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
-
-  if (bin_points.numel() == 0) {
-    AT_CUDA_CHECK(cudaGetLastError());
-    return bin_points;
-  }
-
-  const int chunk_size = 512;
-  const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
-  const size_t blocks = 64;
-  const size_t threads = 512;
-
-  RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
+  // Launch kernel to compute point bboxes
+  const size_t blocks = 128;
+  const size_t threads = 256;
+  PointBoundingBoxKernel<<<blocks, threads, 0, stream>>>(
       points.contiguous().data_ptr<float>(),
-      cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
-      num_points_per_cloud.contiguous().data_ptr<int64_t>(),
       radius.contiguous().data_ptr<float>(),
-      N,
       P,
-      H,
-      W,
-      bin_size,
-      chunk_size,
-      M,
-      points_per_bin.contiguous().data_ptr<int32_t>(),
-      bin_points.contiguous().data_ptr<int32_t>());
-
+      bboxes.contiguous().data_ptr<float>(),
+      should_skip.contiguous().data_ptr<bool>());
   AT_CUDA_CHECK(cudaGetLastError());
-  return bin_points;
+
+  return RasterizeCoarseCuda(
+      bboxes,
+      should_skip,
+      cloud_to_packed_first_idx,
+      num_points_per_cloud,
+      image_size,
+      bin_size,
+      max_points_per_bin);
 }

Comments (0)