Skip to content

Fix windows build #1689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pytorch3d/csrc/marching_cubes/marching_cubes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
compactedVoxelArray,
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
voxelOccupied,
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
voxelOccupiedScan,
uint numVoxels) {
uint id = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -255,7 +255,7 @@ __global__ void GenerateFacesKernel(
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
compactedVoxelArray,
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> numVertsScanned,
const uint activeVoxels,
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
Expand Down Expand Up @@ -471,7 +471,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});

// number of active voxels
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();

const int device_id = vol.device().index();
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
Expand All @@ -492,7 +492,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
numVoxels);
AT_CUDA_CHECK(cudaGetLastError());
cudaDeviceSynchronize();
Expand All @@ -502,7 +502,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});

// total number of vertices
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();

// Execute "GenerateFacesKernel" kernel
// This runs only on the occupied voxels.
Expand All @@ -522,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
activeVoxels,
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
Expand Down