@@ -22,28 +22,12 @@ def _joined_isfile(dirpath: str, basename: str) -> bool:
2222 return os .path .isfile (os .path .join (dirpath , basename ))
2323
2424
25- def _find_nvshmem_header_directory () -> Optional [str ]:
26- if IS_WINDOWS :
27- # nvshmem has no Windows support.
28- return None
29-
25+ def _find_under_site_packages (sub_dir : str , h_basename : str ) -> Optional [str ]:
3026 # Installed from a wheel
31- nvidia_sub_dirs = ("nvidia" , "nvshmem" , "include" )
3227 hdr_dir : str # help mypy
33- for hdr_dir in find_sub_dirs_all_sitepackages (nvidia_sub_dirs ):
34- if _joined_isfile (hdr_dir , "nvshmem.h" ):
35- return hdr_dir
36-
37- conda_prefix = os .environ .get ("CONDA_PREFIX" )
38- if conda_prefix and os .path .isdir (conda_prefix ):
39- hdr_dir = os .path .join (conda_prefix , "include" )
40- if _joined_isfile (hdr_dir , "nvshmem.h" ):
41- return hdr_dir
42-
43- for hdr_dir in sorted (glob .glob ("/usr/include/nvshmem_*" ), reverse = True ):
44- if _joined_isfile (hdr_dir , "nvshmem.h" ):
28+ for hdr_dir in find_sub_dirs_all_sitepackages (tuple (sub_dir .split ("/" ))):
29+ if _joined_isfile (hdr_dir , h_basename ):
4530 return hdr_dir
46-
4731 return None
4832
4933
@@ -54,6 +38,13 @@ def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str)
5438 parts .append ("include" )
5539 idir = os .path .join (* parts )
5640 if libname == "cccl" :
41+ if IS_WINDOWS :
42+ cdir_ctk12 = os .path .join (idir , "targets" , "x64" ) # conda has this anomaly
43+ cdir_ctk13 = os .path .join (cdir_ctk12 , "cccl" )
44+ if _joined_isfile (cdir_ctk13 , h_basename ):
45+ return cdir_ctk13
46+ if _joined_isfile (cdir_ctk12 , h_basename ):
47+ return cdir_ctk12
5748 cdir = os .path .join (idir , "cccl" ) # CTK 13
5849 if _joined_isfile (cdir , h_basename ):
5950 return cdir
@@ -62,38 +53,40 @@ def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str)
6253 return None
6354
6455
65- def _find_based_on_conda_layout (libname : str , h_basename : str , conda_prefix : str ) -> Optional [str ]:
56+ def _find_based_on_conda_layout (libname : str , h_basename : str , ctk_layout : bool ) -> Optional [str ]:
57+ conda_prefix = os .environ .get ("CONDA_PREFIX" )
58+ if not conda_prefix :
59+ return None
6660 if IS_WINDOWS :
6761 anchor_point = os .path .join (conda_prefix , "Library" )
6862 if not os .path .isdir (anchor_point ):
6963 return None
7064 else :
71- targets_include_path = glob .glob (os .path .join (conda_prefix , "targets" , "*" , "include" ))
72- if not targets_include_path :
73- return None
74- if len (targets_include_path ) != 1 :
75- # Conda does not support multiple architectures.
76- # QUESTION(PR#956): Do we want to issue a warning?
77- return None
78- anchor_point = os .path .dirname (targets_include_path [0 ])
65+ if ctk_layout :
66+ targets_include_path = glob .glob (os .path .join (conda_prefix , "targets" , "*" , "include" ))
67+ if not targets_include_path :
68+ return None
69+ if len (targets_include_path ) != 1 :
70+ # Conda does not support multiple architectures.
71+ # QUESTION(PR#956): Do we want to issue a warning?
72+ return None
73+ include_path = targets_include_path [0 ]
74+ else :
75+ include_path = os .path .join (conda_prefix , "include" )
76+ anchor_point = os .path .dirname (include_path )
7977 return _find_based_on_ctk_layout (libname , h_basename , anchor_point )
8078
8179
8280def _find_ctk_header_directory (libname : str ) -> Optional [str ]:
8381 h_basename = supported_nvidia_headers .SUPPORTED_HEADERS_CTK [libname ]
8482 candidate_dirs = supported_nvidia_headers .SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK [libname ]
8583
86- # Installed from a wheel
8784 for cdir in candidate_dirs :
88- hdr_dir : str # help mypy
89- for hdr_dir in find_sub_dirs_all_sitepackages (tuple (cdir .split ("/" ))):
90- if _joined_isfile (hdr_dir , h_basename ):
91- return hdr_dir
85+ if hdr_dir := _find_under_site_packages (cdir , h_basename ):
86+ return hdr_dir
9287
93- conda_prefix = os .environ .get ("CONDA_PREFIX" )
94- if conda_prefix : # noqa: SIM102
95- if result := _find_based_on_conda_layout (libname , h_basename , conda_prefix ):
96- return result
88+ if hdr_dir := _find_based_on_conda_layout (libname , h_basename , True ):
89+ return hdr_dir
9790
9891 cuda_home = get_cuda_home_or_path ()
9992 if cuda_home : # noqa: SIM102
@@ -132,19 +125,28 @@ def find_nvidia_header_directory(libname: str) -> Optional[str]:
132125 3. **CUDA Toolkit environment variables**
133126
134127 - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).
135-
136- Notes:
137- - The ``SUPPORTED_HEADERS_CTK`` dictionary maps each supported CUDA Toolkit
138- (CTK) library to the name of its canonical header (e.g., ``"cublas" →
139- "cublas.h"``). This is used to verify that the located directory is valid.
140-
141- - The only supported non-CTK library at present is ``nvshmem``.
142128 """
143129
144- if libname == "nvshmem" :
145- return _abs_norm (_find_nvshmem_header_directory ())
146-
147130 if libname in supported_nvidia_headers .SUPPORTED_HEADERS_CTK :
148131 return _abs_norm (_find_ctk_header_directory (libname ))
149132
150- raise RuntimeError (f"UNKNOWN { libname = } " )
133+ h_basename = supported_nvidia_headers .SUPPORTED_HEADERS_NON_CTK .get (libname )
134+ if h_basename is None :
135+ raise RuntimeError (f"UNKNOWN { libname = } " )
136+
137+ candidate_dirs = supported_nvidia_headers .SUPPORTED_SITE_PACKAGE_HEADER_DIRS_NON_CTK .get (libname , [])
138+ hdr_dir : Optional [str ] # help mypy
139+ for cdir in candidate_dirs :
140+ if hdr_dir := _find_under_site_packages (cdir , h_basename ):
141+ return _abs_norm (hdr_dir )
142+
143+ if hdr_dir := _find_based_on_conda_layout (libname , h_basename , False ):
144+ return _abs_norm (hdr_dir )
145+
146+ candidate_dirs = supported_nvidia_headers .SUPPORTED_INSTALL_DIRS_NON_CTK .get (libname , [])
147+ for cdir in candidate_dirs :
148+ for hdr_dir in sorted (glob .glob (cdir ), reverse = True ):
149+ if _joined_isfile (hdr_dir , h_basename ):
150+ return _abs_norm (hdr_dir )
151+
152+ return None
0 commit comments