@@ -312,7 +312,7 @@ def get_path(basename: str, extension: str, specified_dir: str = ""):
312
312
subdir = os .path .join (cache_dir (), specified_dir )
313
313
else :
314
314
subdir = os .path .join (cache_dir (), basename [1 :3 ])
315
- path = os .path .join (subdir , f"{ basename } .{ extension } " )
315
+ path = os .path .join (subdir , f"{ basename } .{ extension } " ). replace ( os . sep , "/" )
316
316
return basename , subdir , path
317
317
318
318
@@ -431,7 +431,10 @@ def cpp_compiler_search(search):
431
431
)
432
432
with lock :
433
433
cxx = install_gcc_via_conda ()
434
- subprocess .check_output ([cxx , "--version" ])
434
+ if cxx == "cl" :
435
+ subprocess .check_output ([cxx ])
436
+ else :
437
+ subprocess .check_output ([cxx , "--version" ])
435
438
return cxx
436
439
except (subprocess .SubprocessError , FileNotFoundError , ImportError ):
437
440
continue
@@ -504,7 +507,12 @@ class VecISA:
504
507
505
508
__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};
506
509
507
- extern "C" void __avx_chk_kernel() {
510
+ #ifdef _MSC_VER
511
+ #define DLLEXPORT __declspec(dllexport)
512
+ #else
513
+ #define DLLEXPORT
514
+ #endif
515
+ extern "C" DLLEXPORT void __avx_chk_kernel() {
508
516
auto tmp0 = at::vec::Vectorized<float>(1);
509
517
auto tmp1 = tmp0.exp();
510
518
tmp1.store(in_out_ptr0);
@@ -543,7 +551,7 @@ def __bool__(self):
543
551
lock_dir = get_lock_dir ()
544
552
lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
545
553
with lock :
546
- output_path = input_path [:- 3 ] + "so"
554
+ output_path = input_path [:- 3 ] + ( "so" if sys . platform != "win32" else "dll" )
547
555
build_cmd = shlex .split (
548
556
cpp_compile_command (
549
557
input_path , output_path , warning_all = False , vec_isa = self
@@ -647,6 +655,10 @@ def pick_vec_isa():
647
655
648
656
649
657
def get_shared(shared=True):
    """Return the compiler flag(s) that request a shared-library build.

    Args:
        shared: When false, no flag is emitted.

    Returns:
        ``""`` when ``shared`` is false, and also on Windows when
        ``cpp_compiler()`` is ``cl``, ``clang`` or ``clang-cl``;
        ``"-shared"`` for any other compiler on Windows;
        ``"-shared -fPIC"`` on every other platform.
    """
    # Nothing to emit for a non-shared build, whatever the platform/compiler.
    if not shared:
        return ""
    if sys.platform != "win32":
        return "-shared -fPIC"
    # The Windows branch never adds -fPIC, and emits no flag at all for the
    # compilers listed here.
    if cpp_compiler() in ("cl", "clang", "clang-cl"):
        return ""
    return "-shared"
651
663
652
664
@@ -655,6 +667,8 @@ def get_warning_all_flag(warning_all=True):
655
667
656
668
657
669
def cpp_flags():
    """Return the C++ language-standard flags for the selected compiler.

    Returns:
        ``"/std:c++17"`` when ``cpp_compiler()`` is ``cl`` or ``clang-cl``;
        otherwise ``"-std=c++17 -Wno-unused-variable"``.
    """
    # cl and clang-cl take the slash-prefixed spelling of the standard switch.
    if cpp_compiler() in ("cl", "clang-cl"):
        return "/std:c++17"
    return "-std=c++17 -Wno-unused-variable"
659
673
660
674
@@ -664,6 +678,8 @@ def cpp_wrapper_flags():
664
678
665
679
def optimization_flags ():
666
680
base_flags = "-O3 -ffast-math -fno-finite-math-only"
681
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
682
+ base_flags = "/nologo /O2 /fp:fast"
667
683
if config .is_fbcode ():
668
684
# FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
669
685
# This causes `dlopen` to fail in fbcode, because libgomp does not exist in the default paths.
@@ -674,6 +690,8 @@ def optimization_flags():
674
690
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
675
691
# Also, `-march=native` is unrecognized option on M1
676
692
base_flags += " -Xclang"
693
+ elif sys .platform == "win32" :
694
+ pass
677
695
else :
678
696
if platform .machine () == "ppc64le" :
679
697
base_flags += " -mcpu=native"
@@ -682,12 +700,15 @@ def optimization_flags():
682
700
683
701
# Internal cannot find libgomp.so
684
702
if not config .is_fbcode ():
685
- base_flags += " -fopenmp"
703
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
704
+ base_flags += " /openmp"
705
+ else :
706
+ base_flags += " -fopenmp"
686
707
return base_flags
687
708
688
709
689
710
def use_custom_generated_macros():
    """Return the ``-D`` option that defines ``C10_USING_CUSTOM_GENERATED_MACROS``.

    The macro name is attached directly to ``-D`` with no separating space,
    so the option is a single whitespace-free token. The returned string ends
    with one trailing space.
    """
    return "-DC10_USING_CUSTOM_GENERATED_MACROS "
691
712
692
713
693
714
def use_fb_internal_macros ():
@@ -844,6 +865,9 @@ def get_include_and_linking_paths(
844
865
else :
845
866
libs = ["omp" ] if config .is_fbcode () else ["gomp" ]
846
867
868
+ if sys .platform == "win32" and "gomp" in libs :
869
+ libs .pop (libs .index ("gomp" ))
870
+
847
871
# third party libs
848
872
if config .is_fbcode ():
849
873
ipaths .append (build_paths .sleef ())
@@ -859,9 +883,13 @@ def get_include_and_linking_paths(
859
883
# (later on, we copy the include paths from cpp_extensions into our remote dir)
860
884
ipaths .append ("include" )
861
885
862
- ipaths = " " .join (["-I" + p for p in ipaths ])
863
- lpaths = " " .join (["-L" + p for p in lpaths ])
864
- libs = " " .join (["-l" + p for p in libs ])
886
+ ipaths = " " .join ([f'-I"{ p } "' for p in ipaths ])
887
+ lpaths = " " .join ([f'-L"{ p } "' for p in lpaths ])
888
+ libs = " " .join ([f'-l"{ p } "' for p in libs ])
889
+ if sys .platform == "win32" :
890
+ ipaths = ipaths .replace (os .sep , "/" )
891
+ lpaths = lpaths .replace (os .sep , "/" )
892
+ libs = libs .replace (os .sep , "/" )
865
893
return ipaths , lpaths , libs , macros
866
894
867
895
@@ -892,6 +920,10 @@ def cpp_compile_command(
892
920
inp_name = input
893
921
out_name = output
894
922
linker_paths = "" # let the compiler pick
923
+
924
+ out_dir = ""
925
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
926
+ out_dir = "/Fe:" + os .path .dirname (out_name ) + "/"
895
927
return re .sub (
896
928
r"[ \n]+" ,
897
929
" " ,
@@ -903,7 +935,8 @@ def cpp_compile_command(
903
935
{ use_custom_generated_macros ()}
904
936
{ use_fb_internal_macros ()}
905
937
{ use_standard_sys_dir_headers ()}
906
- -o { out_name }
938
+ { out_dir }
939
+ { "-o " if "cl" not in cpp_compiler () else "/LDd /OUT:" } "{ out_name } "
907
940
""" ,
908
941
).strip ()
909
942
@@ -953,7 +986,7 @@ def compile(cls, graph, source_code, cuda):
953
986
lock_dir = get_lock_dir ()
954
987
lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
955
988
with lock :
956
- output_so = os .path .splitext (input_path )[0 ] + ".so"
989
+ output_so = os .path .splitext (input_path )[0 ] + ( ".so" if sys . platform != "win32" else ".dll" )
957
990
958
991
if not os .path .exists (output_so ):
959
992
cmd = shlex .split (
@@ -1011,9 +1044,16 @@ def cpp_prefix():
1011
1044
# everything that we compile into a folder for remote compilation.
1012
1045
return f'#include "{ os .path .basename (filename )} "'
1013
1046
else :
1047
+ filename = filename .replace (os .sep , "/" )
1014
1048
return f'#include "{ filename } "'
1015
1049
1016
1050
1051
@functools.lru_cache(None)
def output_encoding():
    """Return the name of the locale's preferred text encoding.

    The value is computed on the first call and cached for the lifetime of
    the process.

    Returns:
        The encoding name reported by ``locale.getpreferredencoding``.
    """
    import locale

    # do_setlocale=False: only query the encoding. The default (True) may
    # call setlocale(), which the locale documentation flags as not
    # thread-safe.
    return locale.getpreferredencoding(False)
1055
+
1056
+
1017
1057
# Given a path to an input cpp file and an output path,
1018
1058
# Attempts to compile the file, storing the output in "output_path"
1019
1059
def compile_file (input_path , output_path , cmd ) -> None :
@@ -1045,7 +1085,7 @@ def compile_file(input_path, output_path, cmd) -> None:
1045
1085
else :
1046
1086
subprocess .check_output (cmd , stderr = subprocess .STDOUT )
1047
1087
except subprocess .CalledProcessError as e :
1048
- output = e .output .decode ("utf-8" )
1088
+ output = e .output .decode (output_encoding () )
1049
1089
openmp_problem = "'omp.h' file not found" in output or "libomp" in output
1050
1090
if openmp_problem and sys .platform == "darwin" :
1051
1091
instruction = (
@@ -1095,15 +1135,19 @@ def load(cls, source_code):
1095
1135
lock_dir = get_lock_dir ()
1096
1136
lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
1097
1137
with lock :
1098
- output_path = input_path [:- 3 ] + "so"
1138
+ output_path = input_path [:- 3 ] + ( "so" if sys . platform != "win32" else "dll" )
1099
1139
if not os .path .exists (output_path ):
1100
1140
cmd = shlex .split (
1101
1141
cpp_compile_command (
1102
1142
input = input_path , output = output_path , vec_isa = picked_vec_isa
1103
1143
)
1104
1144
)
1105
1145
compile_file (input_path , output_path , cmd )
1106
- cls .cache [key ] = cls ._load_library (output_path )
1146
+ if sys .platform == "win32" :
1147
+ #cls.cache[key] = cls._load_library(os.path.join(".", os.path.basename(output_path)))
1148
+ cls .cache [key ] = cls ._load_library (output_path )
1149
+ else :
1150
+ cls .cache [key ] = cls ._load_library (output_path )
1107
1151
cls .cache [key ].key = key
1108
1152
1109
1153
return cls .cache [key ]
@@ -1128,7 +1172,7 @@ def load_by_key_path(cls, key, path, linemap=()):
1128
1172
if key not in cls .cache :
1129
1173
with open (path ) as f :
1130
1174
try :
1131
- code = compile (f .read (), path , "exec" )
1175
+ code = compile (f .read (), path . replace ( os . sep , "/" ) , "exec" )
1132
1176
except Exception as e :
1133
1177
raise RuntimeError (
1134
1178
f"Failed to import { path } \n { type (e ).__name__ } : { e } "
@@ -1183,7 +1227,7 @@ def load(cls, source_code, func_name, key, cuda):
1183
1227
if not os .path .exists (cpp_wrapper_dir ):
1184
1228
os .makedirs (cpp_wrapper_dir )
1185
1229
1186
- ext = "so"
1230
+ ext = "so" if sys . platform != "win32" else "dll"
1187
1231
filepath = os .path .join (cpp_wrapper_dir , f"{ name } .{ ext } " )
1188
1232
log .debug ("Cpp wrapper code path %s" , filepath )
1189
1233
0 commit comments