@@ -269,15 +269,17 @@ def run(self):
269269 # First, run the standard build_ext command to compile the extensions
270270 super ().run ()
271271
272- # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
272+ # copy vllm/vllm_flash_attn/**/* .py from self.build_lib to current
273273 # directory so that they can be included in the editable build
274274 import glob
275- files = glob .glob (
276- os .path .join (self .build_lib , "vllm" , "vllm_flash_attn" , "*.py" ))
275+ files = glob .glob (os .path .join (self .build_lib , "vllm" ,
276+ "vllm_flash_attn" , "**" , "*.py" ),
277+ recursive = True )
277278 for file in files :
278279 dst_file = os .path .join ("vllm/vllm_flash_attn" ,
279- os . path . basename ( file ) )
280+ file . split ( "vllm/vllm_flash_attn/" )[ - 1 ] )
280281 print (f"Copying { file } to { dst_file } " )
282+ os .makedirs (os .path .dirname (dst_file ), exist_ok = True )
281283 self .copy_file (file , dst_file )
282284
283285
@@ -377,12 +379,22 @@ def run(self) -> None:
377379 "vllm/_flashmla_C.abi3.so" ,
378380 "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so" ,
379381 "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so" ,
380- "vllm/vllm_flash_attn/flash_attn_interface.py" ,
381382 "vllm/cumem_allocator.abi3.so" ,
382383 # "vllm/_version.py", # not available in nightly wheels yet
383384 ]
384- file_members = filter (lambda x : x .filename in files_to_copy ,
385- wheel .filelist )
385+
386+ file_members = list (
387+ filter (lambda x : x .filename in files_to_copy , wheel .filelist ))
388+
389+ # vllm_flash_attn python code:
390+ # Regex from
391+ # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
392+ import re
393+ compiled_regex = re .compile (
394+ r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" )
395+ file_members += list (
396+ filter (lambda x : compiled_regex .match (x .filename ),
397+ wheel .filelist ))
386398
387399 for file in file_members :
388400 print (f"Extracting and including { file .filename } "
0 commit comments