Error building from source | Windows 10; torch 1.7.0; CUDA 11.4 #843

Closed
SergeiSamuilov opened this issue Sep 15, 2021 · 8 comments
Labels
installation Installation questions or issues

@SergeiSamuilov

After I run python setup.py develop, I get the following error:
log_ninja_enabled.txt
log_ninja_disabled.txt

C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\ATen/core/ivalue_inl.h(389): warning C4101: 'e': unreferenced local variable
C:\pytorch3d-main\pytorch3d\csrc\sample_pdf\sample_pdf_cpu.cpp(35): error C2146: syntax error: missing ';' before identifier 'output_p'
C:\pytorch3d-main\pytorch3d\csrc\sample_pdf\sample_pdf_cpu.cpp(35): error C2065: 'output_p': undeclared identifier
C:\pytorch3d-main\pytorch3d\csrc\sample_pdf\sample_pdf_cpu.cpp(67): error C2065: 'output_p': undeclared identifier
C:\pytorch3d-main\pytorch3d\csrc\sample_pdf\sample_pdf_cpu.cpp(96): error C2065: 'output_p': undeclared identifier
C:\pytorch3d-main\pytorch3d\csrc\sample_pdf\sample_pdf_cpu.cpp(97): error C2065: 'output_p': undeclared identifier
error: command 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30037\bin\HostX86\x64\cl.exe' failed with exit status 2

Could you please clarify what I'm doing wrong?

@gkioxari gkioxari added the installation Installation questions or issues label Sep 15, 2021
@SergeiSamuilov
Author

I also tried building PyTorch3D version 0.4.0, since it doesn't include sample_pdf, and got a new error:

C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.4/include\thrust/system/cuda/config.h(78): fatal error C1189: #error: The version of CUB in your include path is not compatible with this release of Thrust. CUB is now included in the CUDA Toolkit, so you no longer need to use your own checkout of CUB. Define THRUST_IGNORE_CUB_VERSION_CHECK to ignore this.
error: command 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe' failed with exit status 2

@SergeiSamuilov
Author

After I deleted the CUB_HOME environment variable and rebuilt 0.4.0, the error is:

C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\detail/common.h(108): warning C4005: 'HAVE_SNPRINTF': macro redefinition
C:\Users\PC\AppData\Local\Programs\Python\Python38\include\pyerrors.h(315): note: see previous definition of 'HAVE_SNPRINTF'
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe -c C:\pytorch3d-0.4.0\pytorch3d\csrc\blending\sigmoid_alpha_blend.cu -o build\temp.win-amd64-3.8\Release\pytorch3d-0.4.0\pytorch3d\csrc\blending\sigmoid_alpha_blend.obj -IC:\pytorch3d-0.4.0\pytorch3d\csrc -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\torch\csrc\api\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\TH -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\THC "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\include" -IC:\Users\PC\AppData\Local\Programs\Python\Python38\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\include "-IC:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30037\include" "-IC:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\ucrt" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\shared" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\um" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\winrt" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\cppwinrt" -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -std=c++14 -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ 
-D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --use-local-env
sigmoid_alpha_blend.cu
C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\detail/common.h(108): warning C4005: 'HAVE_SNPRINTF': macro redefinition
C:\Users\PC\AppData\Local\Programs\Python\Python38\include\pyerrors.h(315): note: see previous definition of 'HAVE_SNPRINTF'
C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\cast.h(1429): error: too few arguments for template template parameter "Tuple"
detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\cast.h(1503): error: too few arguments for template template parameter "Tuple"
detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

2 errors detected in the compilation of "c:/pytorch3d-0.4.0/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu".
error: command 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe' failed with exit status 1

@bottler
Contributor

bottler commented Sep 16, 2021

Several things going on:

  • We haven't tried building with CUDA 11.4 at all yet; I think the latest PyTorch (1.9.0) is designed to work with CUDA 10.2 and 11.1, and PyTorch 1.7.0, which you are using, had CUDA 11.0 as its newest supported version. That said, your errors don't look related to this.

  • The error in your first message looks like a problem with the use of __restrict__ in C++. __restrict__ is nonstandard and isn't recognized by MSVC. It may help simply to delete the single occurrence of __restrict__ in sample_pdf_cpu.cpp. I suspect it might be a problem on some Linux builds too. I plan to remove it from the code.

  • setup.py always defines THRUST_IGNORE_CUB_VERSION_CHECK when building with CUDA support, so it is odd that you got an error suggesting you set it.

  • I don't understand the pybind problem with sigmoid_alpha_blend. Is that with ninja?
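For reference, one common way to keep a restrict hint without breaking MSVC is to hide the compiler-specific spelling behind a macro: GCC and Clang accept __restrict__, while MSVC spells the extension __restrict. The sketch below assumes that approach; the macro name P3D_RESTRICT and the function scale_into are hypothetical and not part of PyTorch3D (the later fix in this thread simply removed the qualifier from the .cpp file).

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical portability macro (not defined by PyTorch3D).
// restrict is not a standard C++ keyword; each compiler offers an extension.
#if defined(_MSC_VER)
#define P3D_RESTRICT __restrict
#else
#define P3D_RESTRICT __restrict__
#endif

// Writes src[i] * factor into dst[i] for i in [0, n).
// The restrict qualifiers promise the compiler that dst and src do not
// overlap, which may let it vectorize the loop more aggressively.
void scale_into(float* P3D_RESTRICT dst, const float* P3D_RESTRICT src,
                std::size_t n, float factor) {
  assert(n == 0 || (dst != nullptr && src != nullptr));
  for (std::size_t i = 0; i < n; ++i) {
    dst[i] = src[i] * factor;
  }
}
```

Passing overlapping buffers to such a function is undefined behavior, so dropping the qualifier, as was done here, is the safer choice when the speedup is not needed.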

@SergeiSamuilov
Author

  1. Regarding the first message: I tried building without the __restrict__ keyword and still got the same error:
    C:\pytorch3d-main\pytorch3d\csrc\sample_pdf\sample_pdf_cpu.cpp(35): error C2146: syntax error: missing ';' before identifier 'output_p'
    sample_pdf_cpu.txt
  2. That was with ninja disabled, for more verbose output.

@bottler
Contributor

bottler commented Sep 17, 2021

  1. Something really weird is going on. Unmodified, line 35 is understandably a syntax error in MSVC. Removing the __restrict__ should leave line 35 looking like this:
float* output_p = outputs.data_ptr<float>() + start_batch * n_samples;

In your new version, you are missing the first *, which should cause some sort of type error, but the syntax would be valid. The syntax error suggests that the compiler is still seeing the old code, not your change.
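To make the intent of the corrected line 35 concrete, here is a minimal sketch assuming outputs is a contiguous, row-major float buffer of shape (batch, n_samples). A std::vector<float> stands in for the tensor and data() plays the role of outputs.data_ptr<float>(); the function fill_rows is hypothetical. It shows why output_p must be declared as float*: it points at the first element of row start_batch, and the loop then walks forward from there.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: fills rows [start_batch, end_batch) of a
// (batch, n_samples) row-major buffer with `value`, leaving other rows alone.
void fill_rows(std::vector<float>& outputs, std::int64_t n_samples,
               std::int64_t start_batch, std::int64_t end_batch, float value) {
  assert(start_batch <= end_batch);
  assert(static_cast<std::int64_t>(outputs.size()) >= end_batch * n_samples);
  // Same shape as the corrected line 35: a float* offset to row start_batch.
  float* output_p = outputs.data() + start_batch * n_samples;
  for (std::int64_t b = start_batch; b < end_batch; ++b) {
    for (std::int64_t s = 0; s < n_samples; ++s) {
      *output_p++ = value;  // advance one element at a time through the rows
    }
  }
}
```

Splitting the batch into [start_batch, end_batch) ranges like this is what lets separate worker threads write to disjoint parts of the same output without coordination.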

  1. I reckon your CUDA setup cannot parse torch/extension.h: it is complaining about our one .cu file that includes it. We used to be careful to avoid including it (see 85c396f822) because there were other problems with it on Linux with old compilers. You may be suffering because of your new CUDA (e.g. there may be conditional compilation inside the pybind11 version bundled with torch 1.7 that responds wrongly to the macros set by the new CUDA compiler). It should be possible for you to change that file so it no longer includes torch/extension.h, e.g. by switching torch::PackedTensorAccessor64 to at::PackedTensorAccessor64. We might make this change too.
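The logs point at pybind11's tuple_caster, which takes a variadic template template parameter and is instantiated with std::pair. The sketch below is a simplified, hypothetical stand-in for that pattern (not pybind11's real code). Host compilers such as GCC and Clang accept it, because since C++11 a template template parameter declared with a parameter pack may bind to a template with a fixed number of parameters. That is consistent with the guess above that the failure lies in how the new CUDA toolchain parses the header, rather than in PyTorch3D's code.

```cpp
#include <cstddef>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

// Hypothetical, simplified stand-in for pybind11::detail::tuple_caster:
// the first parameter is a variadic template template parameter.
template <template <typename...> class Tuple, typename... Ts>
struct tuple_caster_sketch {
  using type = Tuple<Ts...>;
  static constexpr std::size_t size = sizeof...(Ts);
};

// Binding std::pair (exactly two type parameters) to the variadic Tuple
// parameter is the step nvcc reported as "too few arguments for template
// template parameter" in the logs above.
template <typename T1, typename T2>
struct pair_caster_sketch : tuple_caster_sketch<std::pair, T1, T2> {};

// The same template also accepts std::tuple with any number of types.
template <typename... Ts>
struct std_tuple_caster_sketch : tuple_caster_sketch<std::tuple, Ts...> {};

// Compile-time check: this line compiles only if std::pair binds correctly.
static_assert(std::is_same<pair_caster_sketch<int, std::string>::type,
                           std::pair<int, std::string>>::value,
              "std::pair binds to a variadic template template parameter");
```

Because the problem is in a third-party header rather than in PyTorch3D's sources, avoiding the include, as suggested above, sidesteps it without touching the CUDA installation.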

@SergeiSamuilov
Author

SergeiSamuilov commented Sep 20, 2021

  1. Error after I changed the line:

C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe -c C:\pytorch3d-main\pytorch3d\csrc\ball_query\ball_query.cu -o build\temp.win-amd64-3.8\Release\pytorch3d-main\pytorch3d\csrc\ball_query\ball_query.obj -IC:\pytorch3d-main\pytorch3d\csrc -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\torch\csrc\api\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\TH -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\THC "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\include" -IC:\Users\PC\AppData\Local\Programs\Python\Python38\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\include "-IC:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30037\include" "-IC:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\ucrt" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\shared" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\um" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\winrt" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\cppwinrt" -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -std=c++14 -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ 
-D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --use-local-env
ball_query.cu
C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\detail/common.h(108): warning C4005: 'HAVE_SNPRINTF': macro redefinition
C:\Users\PC\AppData\Local\Programs\Python\Python38\include\pyerrors.h(315): note: see previous definition of 'HAVE_SNPRINTF'
C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\cast.h(1429): error: too few arguments for template template parameter "Tuple"
detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\cast.h(1503): error: too few arguments for template template parameter "Tuple"
detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

2 errors detected in the compilation of "c:/pytorch3d-main/pytorch3d/csrc/ball_query/ball_query.cu".
error: command 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe' failed with exit status 1

  1. Unfortunately, I can't downgrade CUDA, as another extension I use requires a version newer than 11.1. Switching torch::PackedTensorAccessor64 to at::PackedTensorAccessor64 in sigmoid_alpha_blend.cu yields the same error.

sigmoid_alpha_blend_cpu.txt

C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe -c C:\pytorch3d-0.4.0\pytorch3d\csrc\blending\sigmoid_alpha_blend.cu -o build\temp.win-amd64-3.8\Release\pytorch3d-0.4.0\pytorch3d\csrc\blending\sigmoid_alpha_blend.obj -IC:\pytorch3d-0.4.0\pytorch3d\csrc -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\torch\csrc\api\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\TH -IC:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\THC "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\include" -IC:\Users\PC\AppData\Local\Programs\Python\Python38\include -IC:\Users\PC\AppData\Local\Programs\Python\Python38\include "-IC:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30037\include" "-IC:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\ucrt" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\shared" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\um" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\winrt" "-IC:\Program Files (x86)\Windows Kits\10\include\10.0.19041.0\cppwinrt" -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -std=c++14 -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ 
-D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --use-local-env
sigmoid_alpha_blend.cu
C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\detail/common.h(108): warning C4005: 'HAVE_SNPRINTF': macro redefinition
C:\Users\PC\AppData\Local\Programs\Python\Python38\include\pyerrors.h(315): note: see previous definition of 'HAVE_SNPRINTF'
C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\cast.h(1429): error: too few arguments for template template parameter "Tuple"
detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

C:\Users\PC\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\include\pybind11\cast.h(1503): error: too few arguments for template template parameter "Tuple"
detected during instantiation of class "pybind11::detail::tuple_caster<Tuple, Ts...> [with Tuple=std::pair, Ts=<T1, T2>]"
(1507): here

2 errors detected in the compilation of "c:/pytorch3d-0.4.0/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu".
error: command 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\nvcc.exe' failed with exit status 1

@bottler
Contributor

bottler commented Sep 20, 2021

The change to sigmoid_alpha_blend.cu should look like this
leaving you with
sigmoid_alpha_blend.cu

@SergeiSamuilov
Author

Yeah, it worked. Built 0.4.0. Thanks a lot!!!

@bottler bottler closed this as completed Sep 20, 2021
facebook-github-bot pushed a commit that referenced this issue Sep 22, 2021
Summary: Remove use of nonstandard C++. Noticed on windows in issue #843. (We use `__restrict__` in CUDA, where it is fine, even on windows)

Reviewed By: nikhilaravi

Differential Revision: D31006516

fbshipit-source-id: 929ba9b3216cb70fad3ffa3274c910618d83973f
facebook-github-bot pushed a commit that referenced this issue Sep 22, 2021
Summary: Unlike other cu files, sigmoid_alpha_blend uses torch/extension.h. Avoid for possible build speed win and because of a reported problem #843 on windows with CUDA 11.4.

Reviewed By: nikhilaravi

Differential Revision: D31054121

fbshipit-source-id: 53a1f985a1695a044dfd2ee1a5b0adabdf280595
@facebookresearch facebookresearch locked as resolved and limited conversation to collaborators Oct 6, 2021

3 participants