@@ -342,7 +342,7 @@ def get_path(
342
342
subdir = os .path .join (cache_dir (), specified_dir )
343
343
else :
344
344
subdir = os .path .join (cache_dir (), basename [1 :3 ])
345
- path = os .path .join (subdir , f"{ basename } .{ extension } " )
345
+ path = os .path .join (subdir , f"{ basename } .{ extension } " ). replace ( os . sep , "/" )
346
346
return basename , subdir , path
347
347
348
348
@@ -921,7 +921,10 @@ def cpp_compiler_search(search: str) -> str:
921
921
)
922
922
with lock :
923
923
cxx = install_gcc_via_conda ()
924
- subprocess .check_output ([cxx , "--version" ])
924
+ if cxx == "cl" :
925
+ subprocess .check_output ([cxx ])
926
+ else :
927
+ subprocess .check_output ([cxx , "--version" ])
925
928
return cxx
926
929
except (subprocess .SubprocessError , FileNotFoundError , ImportError ):
927
930
continue
@@ -998,7 +1001,12 @@ class VecISA:
998
1001
999
1002
__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};
1000
1003
1001
- extern "C" void __avx_chk_kernel() {
1004
+ #ifdef _MSC_VER
1005
+ #define DLLEXPORT __declspec(dllexport)
1006
+ #else
1007
+ #define DLLEXPORT
1008
+ #endif
1009
+ extern "C" DLLEXPORT void __avx_chk_kernel() {
1002
1010
auto tmp0 = at::vec::Vectorized<float>(1);
1003
1011
auto tmp1 = tmp0.exp();
1004
1012
tmp1.store(in_out_ptr0);
@@ -1040,7 +1048,7 @@ def __bool__(self) -> bool:
1040
1048
lock_dir = get_lock_dir ()
1041
1049
lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
1042
1050
with lock :
1043
- output_path = input_path [:- 3 ] + "so"
1051
+ output_path = input_path [:- 3 ] + ( "so" if sys . platform != "win32" else "dll" )
1044
1052
build_cmd = shlex .split (
1045
1053
cpp_compile_command (
1046
1054
input_path , output_path , warning_all = False , vec_isa = self
@@ -1167,6 +1175,10 @@ def get_compile_only(compile_only: bool = True) -> str:
1167
1175
1168
1176
1169
1177
def get_shared (shared : bool = True ) -> str :
1178
+ if sys .platform == "win32" :
1179
+ if cpp_compiler () in ["cl" , "clang" , "clang-cl" ]:
1180
+ return ""
1181
+ return "-shared" if shared else ""
1170
1182
return "-shared -fPIC" if shared else ""
1171
1183
1172
1184
@@ -1180,6 +1192,8 @@ def get_glibcxx_abi_build_flags() -> str:
1180
1192
1181
1193
def cpp_flags () -> str :
1182
1194
flags = ["-std=c++17" , "-Wno-unused-variable" , "-Wno-unknown-pragmas" ]
1195
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
1196
+ return "/std:c++17"
1183
1197
if is_clang ():
1184
1198
flags .append ("-Werror=ignored-optimization-argument" )
1185
1199
return " " .join (flags )
@@ -1192,6 +1206,8 @@ def cpp_wrapper_flags() -> str:
1192
1206
def optimization_flags () -> str :
1193
1207
base_flags = "-O0 -g" if config .aot_inductor .debug_compile else "-O3 -DNDEBUG"
1194
1208
base_flags += " -ffast-math -fno-finite-math-only"
1209
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
1210
+ base_flags = "/nologo /O2 /fp:fast"
1195
1211
if not config .cpp .enable_unsafe_math_opt_flag :
1196
1212
base_flags += " -fno-unsafe-math-optimizations"
1197
1213
@@ -1205,6 +1221,8 @@ def optimization_flags() -> str:
1205
1221
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
1206
1222
# Also, `-march=native` is unrecognized option on M1
1207
1223
base_flags += " -Xclang"
1224
+ elif sys .platform == "win32" :
1225
+ pass
1208
1226
else :
1209
1227
if platform .machine () == "ppc64le" :
1210
1228
base_flags += " -mcpu=native"
@@ -1213,12 +1231,15 @@ def optimization_flags() -> str:
1213
1231
1214
1232
# Internal cannot find libgomp.so
1215
1233
if not config .is_fbcode ():
1216
- base_flags += " -fopenmp"
1234
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
1235
+ base_flags += " /openmp"
1236
+ else :
1237
+ base_flags += " -fopenmp"
1217
1238
return base_flags
1218
1239
1219
1240
1220
1241
def use_custom_generated_macros () -> str :
1221
- return "-D C10_USING_CUSTOM_GENERATED_MACROS "
1242
+ return "-DC10_USING_CUSTOM_GENERATED_MACROS "
1222
1243
1223
1244
1224
1245
def use_fb_internal_macros () -> str :
@@ -1406,6 +1427,8 @@ def get_include_and_linking_paths(
1406
1427
# and raise error together with instructions at compilation error later
1407
1428
else :
1408
1429
libs = ["omp" ] if config .is_fbcode () else ["gomp" ]
1430
+ if sys .platform == "win32" and "gomp" in libs :
1431
+ libs .pop (libs .index ("gomp" ))
1409
1432
1410
1433
# Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
1411
1434
if not config .aot_inductor .abi_compatible :
@@ -1434,6 +1457,11 @@ def get_include_and_linking_paths(
1434
1457
1435
1458
lpaths_str = " " .join (["-L" + p for p in lpaths ])
1436
1459
libs_str = " " .join (static_link_libs + ["-l" + p for p in libs ])
1460
+ if sys .platform == "win32" :
1461
+ ipaths = [p .replace (os .sep , "/" ) for p in ipaths ]
1462
+ lpaths_str = lpaths_str .replace (os .sep , "/" )
1463
+ libs_str = libs_str .replace (os .sep , "/" )
1464
+
1437
1465
return ipaths , lpaths_str , libs_str , macros , build_arch_flags
1438
1466
1439
1467
@@ -1454,7 +1482,7 @@ def cpp_compile_command(
1454
1482
)
1455
1483
if isinstance (input , str ):
1456
1484
input = [input ]
1457
- ipaths_str = " " .join ([" -I" + p for p in ipaths ])
1485
+ ipaths_str = " " .join ([f' -I"{ p } "' for p in ipaths ])
1458
1486
clang_flags = ""
1459
1487
if config .is_fbcode ():
1460
1488
if aot_mode and not use_absolute_path :
@@ -1475,6 +1503,10 @@ def cpp_compile_command(
1475
1503
out_name = output
1476
1504
linker_paths = "" # let the compiler pick
1477
1505
inp_name_str = " " .join (inp_name )
1506
+
1507
+ out_dir = ""
1508
+ if cpp_compiler () in ["cl" , "clang-cl" ]:
1509
+ out_dir = "/Fe:" + os .path .dirname (out_name ) + "/"
1478
1510
return re .sub (
1479
1511
r"[ \n]+" ,
1480
1512
" " ,
@@ -1489,7 +1521,8 @@ def cpp_compile_command(
1489
1521
{ use_fb_internal_macros ()}
1490
1522
{ use_standard_sys_dir_headers ()}
1491
1523
{ get_compile_only (compile_only )}
1492
- -o { out_name }
1524
+ { out_dir }
1525
+ { "-o " if "cl" not in cpp_compiler () else "/LDd /OUT:" } "{ out_name } "
1493
1526
""" ,
1494
1527
).strip ()
1495
1528
@@ -1600,7 +1633,7 @@ def compile(
1600
1633
output_so = (
1601
1634
config .aot_inductor .output_path
1602
1635
if specified_so_name
1603
- else os .path .splitext (input_path )[0 ] + ".so"
1636
+ else os .path .splitext (input_path )[0 ] + ( ".so" if sys . platform != "win32" else ".dll" )
1604
1637
)
1605
1638
1606
1639
if not os .path .exists (output_so ):
@@ -1741,9 +1774,16 @@ def cpp_prefix() -> str:
1741
1774
# everything that we compile into a folder for remote compilation.
1742
1775
return f'#include "{ os .path .basename (filename )} "'
1743
1776
else :
1777
+ filename = filename .replace (os .sep , "/" )
1744
1778
return f'#include "{ filename } "'
1745
1779
1746
1780
1781
+ @functools .lru_cache (None )
1782
+ def output_encoding ():
1783
+ import locale
1784
+ return locale .getpreferredencoding ()
1785
+
1786
+
1747
1787
# Given a path to an input cpp file and an output path,
1748
1788
# Attempts to compile the file, storing the output in "output_path"
1749
1789
def compile_file (
@@ -1781,7 +1821,7 @@ def compile_file(
1781
1821
else :
1782
1822
subprocess .check_output (cmd , stderr = subprocess .STDOUT )
1783
1823
except subprocess .CalledProcessError as e :
1784
- output = e .output .decode ("utf-8" )
1824
+ output = e .output .decode (output_encoding () )
1785
1825
openmp_problem = "'omp.h' file not found" in output or "libomp" in output
1786
1826
if openmp_problem and sys .platform == "darwin" :
1787
1827
instruction = (
@@ -1834,7 +1874,7 @@ def load(cls, source_code: str) -> CDLL:
1834
1874
lock_dir = get_lock_dir ()
1835
1875
lock = FileLock (os .path .join (lock_dir , key + ".lock" ), timeout = LOCK_TIMEOUT )
1836
1876
with lock :
1837
- output_path = input_path [:- 3 ] + "so"
1877
+ output_path = input_path [:- 3 ] + ( "so" if sys . platform != "win32" else "dll" )
1838
1878
if not os .path .exists (output_path ):
1839
1879
cmd = shlex .split (
1840
1880
cpp_compile_command (
@@ -1881,7 +1921,7 @@ def load_by_key_path(
1881
1921
if key not in cls .cache :
1882
1922
with open (path ) as f :
1883
1923
try :
1884
- code = compile (f .read (), path , "exec" )
1924
+ code = compile (f .read (), path . replace ( os . sep , "/" ) , "exec" )
1885
1925
except Exception as e :
1886
1926
raise RuntimeError (
1887
1927
f"Failed to import { path } \n { type (e ).__name__ } : { e } "
@@ -1941,7 +1981,7 @@ def load(cls, source_code: str, func_name: str, key: str, cuda: bool) -> CDLL:
1941
1981
cpp_wrapper_dir = cpp_wrapper_cache_dir (name )
1942
1982
os .makedirs (cpp_wrapper_dir , exist_ok = True )
1943
1983
1944
- ext = "so"
1984
+ ext = "so" if sys . platform != "win32" else "dll"
1945
1985
filepath = os .path .join (cpp_wrapper_dir , f"{ name } .{ ext } " )
1946
1986
log .debug ("Cpp wrapper code path %s" , filepath )
1947
1987
0 commit comments