13
13
#include < thrust/scan.h>
14
14
#include < cstdio>
15
15
#include " marching_cubes/tables.h"
16
- #include " utils/pytorch3d_cutils.h"
17
16
18
17
/*
19
18
Parallelized marching cubes for pytorch extension
@@ -267,13 +266,12 @@ __global__ void CompactVoxelsKernel(
267
266
// isolevel: threshold to determine isosurface intersection
268
267
//
269
268
__global__ void GenerateFacesKernel (
270
- torch ::PackedTensorAccessor32<float , 2 , torch ::RestrictPtrTraits> verts,
271
- torch ::PackedTensorAccessor<int64_t , 2 , torch ::RestrictPtrTraits> faces,
272
- torch ::PackedTensorAccessor<int64_t , 1 , torch ::RestrictPtrTraits> ids,
273
- torch ::PackedTensorAccessor32<int , 1 , torch ::RestrictPtrTraits>
269
+ at ::PackedTensorAccessor32<float , 2 , at ::RestrictPtrTraits> verts,
270
+ at ::PackedTensorAccessor<int64_t , 2 , at ::RestrictPtrTraits> faces,
271
+ at ::PackedTensorAccessor<int64_t , 1 , at ::RestrictPtrTraits> ids,
272
+ at ::PackedTensorAccessor32<int , 1 , at ::RestrictPtrTraits>
274
273
compactedVoxelArray,
275
- torch::PackedTensorAccessor32<int , 1 , torch::RestrictPtrTraits>
276
- numVertsScanned,
274
+ at::PackedTensorAccessor32<int , 1 , at::RestrictPtrTraits> numVertsScanned,
277
275
const uint activeVoxels,
278
276
const at::PackedTensorAccessor32<float , 3 , at::RestrictPtrTraits> vol,
279
277
const at::PackedTensorAccessor32<int , 2 , at::RestrictPtrTraits> faceTable,
@@ -436,15 +434,15 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
436
434
cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
437
435
438
436
// transfer _FACE_TABLE data to device
439
- torch ::Tensor face_table_tensor = torch ::zeros (
440
- {256 , 16 }, torch ::TensorOptions ().dtype (at::kInt ).device (at::kCPU ));
437
+ at ::Tensor face_table_tensor = at ::zeros (
438
+ {256 , 16 }, at ::TensorOptions ().dtype (at::kInt ).device (at::kCPU ));
441
439
auto face_table_a = face_table_tensor.accessor <int , 2 >();
442
440
for (int i = 0 ; i < 256 ; i++) {
443
441
for (int j = 0 ; j < 16 ; j++) {
444
442
face_table_a[i][j] = _FACE_TABLE[i][j];
445
443
}
446
444
}
447
- torch ::Tensor faceTable = face_table_tensor.to (vol.device ());
445
+ at ::Tensor faceTable = face_table_tensor.to (vol.device ());
448
446
449
447
// get numVoxels
450
448
int threads = 128 ;
@@ -458,10 +456,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
458
456
}
459
457
460
458
auto d_voxelVerts =
461
- torch ::zeros ({numVoxels}, torch ::TensorOptions ().dtype (at::kInt ))
459
+ at ::zeros ({numVoxels}, at ::TensorOptions ().dtype (at::kInt ))
462
460
.to (vol.device ());
463
461
auto d_voxelOccupied =
464
- torch ::zeros ({numVoxels}, torch ::TensorOptions ().dtype (at::kInt ))
462
+ at ::zeros ({numVoxels}, at ::TensorOptions ().dtype (at::kInt ))
465
463
.to (vol.device ());
466
464
467
465
// Execute "ClassifyVoxelKernel" kernel to precompute
@@ -480,7 +478,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
480
478
// If the number of active voxels is 0, return zero tensor for verts and
481
479
// faces.
482
480
auto d_voxelOccupiedScan =
483
- torch ::zeros ({numVoxels}, torch ::TensorOptions ().dtype (at::kInt ))
481
+ at ::zeros ({numVoxels}, at ::TensorOptions ().dtype (at::kInt ))
484
482
.to (vol.device ());
485
483
ThrustScanWrapper (
486
484
d_voxelOccupiedScan.data_ptr <int >(),
@@ -493,23 +491,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
493
491
int activeVoxels = lastElement + lastScan;
494
492
495
493
const int device_id = vol.device ().index ();
496
- auto opt =
497
- torch::TensorOptions ().dtype (torch::kInt ).device (torch::kCUDA , device_id);
498
- auto opt_long = torch::TensorOptions ()
499
- .dtype (torch::kInt64 )
500
- .device (torch::kCUDA , device_id);
494
+ auto opt = at::TensorOptions ().dtype (at::kInt ).device (at::kCUDA , device_id);
495
+ auto opt_long =
496
+ at::TensorOptions ().dtype (at::kLong ).device (at::kCUDA , device_id);
501
497
502
498
if (activeVoxels == 0 ) {
503
499
int ntris = 0 ;
504
- torch ::Tensor verts = torch ::zeros ({ntris * 3 , 3 }, vol.options ());
505
- torch ::Tensor faces = torch ::zeros ({ntris, 3 }, opt_long);
506
- torch ::Tensor ids = torch ::zeros ({ntris}, opt_long);
500
+ at ::Tensor verts = at ::zeros ({ntris * 3 , 3 }, vol.options ());
501
+ at ::Tensor faces = at ::zeros ({ntris, 3 }, opt_long);
502
+ at ::Tensor ids = at ::zeros ({ntris}, opt_long);
507
503
return std::make_tuple (verts, faces, ids);
508
504
}
509
505
510
506
// Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
511
507
// This allows us to run triangle generation on only the occupied voxels.
512
- auto d_compVoxelArray = torch ::zeros ({activeVoxels}, opt);
508
+ auto d_compVoxelArray = at ::zeros ({activeVoxels}, opt);
513
509
CompactVoxelsKernel<<<grid, threads, 0 , stream>>> (
514
510
d_compVoxelArray.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
515
511
d_voxelOccupied.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
@@ -519,7 +515,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
519
515
cudaDeviceSynchronize ();
520
516
521
517
// Scan d_voxelVerts array to generate offsets of vertices for each voxel
522
- auto d_voxelVertsScan = torch ::zeros ({numVoxels}, opt);
518
+ auto d_voxelVertsScan = at ::zeros ({numVoxels}, opt);
523
519
ThrustScanWrapper (
524
520
d_voxelVertsScan.data_ptr <int >(),
525
521
d_voxelVerts.data_ptr <int >(),
@@ -533,10 +529,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
533
529
// Execute "GenerateFacesKernel" kernel
534
530
// This runs only on the occupied voxels.
535
531
// It looks up the field values and generates the triangle data.
536
- torch ::Tensor verts = torch ::zeros ({totalVerts, 3 }, vol.options ());
537
- torch ::Tensor faces = torch ::zeros ({totalVerts / 3 , 3 }, opt_long);
532
+ at ::Tensor verts = at ::zeros ({totalVerts, 3 }, vol.options ());
533
+ at ::Tensor faces = at ::zeros ({totalVerts / 3 , 3 }, opt_long);
538
534
539
- torch ::Tensor ids = torch ::zeros ({totalVerts}, opt_long);
535
+ at ::Tensor ids = at ::zeros ({totalVerts}, opt_long);
540
536
541
537
dim3 grid2 ((activeVoxels + threads - 1 ) / threads, 1 , 1 );
542
538
if (grid2.x > 65535 ) {
0 commit comments