Skip to content

Commit 7cd21b1

Browse files
committed
MSVC fixes
1 parent 94db657 commit 7cd21b1

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

torch/_inductor/codecache.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def get_path(
342342
subdir = os.path.join(cache_dir(), specified_dir)
343343
else:
344344
subdir = os.path.join(cache_dir(), basename[1:3])
345-
path = os.path.join(subdir, f"{basename}.{extension}")
345+
path = os.path.join(subdir, f"{basename}.{extension}").replace(os.sep, "/")
346346
return basename, subdir, path
347347

348348

@@ -921,7 +921,10 @@ def cpp_compiler_search(search: str) -> str:
921921
)
922922
with lock:
923923
cxx = install_gcc_via_conda()
924-
subprocess.check_output([cxx, "--version"])
924+
if cxx == "cl":
925+
subprocess.check_output([cxx])
926+
else:
927+
subprocess.check_output([cxx, "--version"])
925928
return cxx
926929
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
927930
continue
@@ -998,7 +1001,12 @@ class VecISA:
9981001
9991002
__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};
10001003
1001-
extern "C" void __avx_chk_kernel() {
1004+
#ifdef _MSC_VER
1005+
#define DLLEXPORT __declspec(dllexport)
1006+
#else
1007+
#define DLLEXPORT
1008+
#endif
1009+
extern "C" DLLEXPORT void __avx_chk_kernel() {
10021010
auto tmp0 = at::vec::Vectorized<float>(1);
10031011
auto tmp1 = tmp0.exp();
10041012
tmp1.store(in_out_ptr0);
@@ -1040,7 +1048,7 @@ def __bool__(self) -> bool:
10401048
lock_dir = get_lock_dir()
10411049
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
10421050
with lock:
1043-
output_path = input_path[:-3] + "so"
1051+
output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll")
10441052
build_cmd = shlex.split(
10451053
cpp_compile_command(
10461054
input_path, output_path, warning_all=False, vec_isa=self
@@ -1167,6 +1175,10 @@ def get_compile_only(compile_only: bool = True) -> str:
11671175

11681176

11691177
def get_shared(shared: bool = True) -> str:
    """Return the compiler flag(s) used to request a shared-library build.

    Args:
        shared: Whether a shared library is being built.

    Returns:
        ``"-shared -fPIC"`` on non-Windows platforms and ``"-shared"`` on
        Windows, or ``""`` when ``shared`` is false. On Windows the result
        is always ``""`` if the compiler is ``cl``, ``clang`` or
        ``clang-cl``, regardless of ``shared``.
    """
    # Non-Windows platforms never need to look at the compiler name.
    if sys.platform != "win32":
        return "-shared -fPIC" if shared else ""

    # These drivers get no flag from this helper on Windows.
    if cpp_compiler() in ("cl", "clang", "clang-cl"):
        return ""

    return "-shared" if shared else ""
11711183

11721184

@@ -1180,6 +1192,8 @@ def get_glibcxx_abi_build_flags() -> str:
11801192

11811193
def cpp_flags() -> str:
    """Return the language-standard and warning flags as one string.

    Returns:
        ``"/std:c++17"`` when the compiler is ``cl`` or ``clang-cl``.
        Otherwise ``-std=c++17`` plus two warning suppressions, with
        ``-Werror=ignored-optimization-argument`` appended when
        ``is_clang()`` is true; the flags are space-separated.
    """
    # MSVC-style drivers take a single slash-prefixed option here.
    if cpp_compiler() in ("cl", "clang-cl"):
        return "/std:c++17"

    gnu_style = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"]
    if is_clang():
        gnu_style.append("-Werror=ignored-optimization-argument")
    return " ".join(gnu_style)
@@ -1192,6 +1206,8 @@ def cpp_wrapper_flags() -> str:
11921206
def optimization_flags() -> str:
11931207
base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG"
11941208
base_flags += " -ffast-math -fno-finite-math-only"
1209+
if cpp_compiler() in ["cl", "clang-cl"]:
1210+
base_flags = "/nologo /O2 /fp:fast"
11951211
if not config.cpp.enable_unsafe_math_opt_flag:
11961212
base_flags += " -fno-unsafe-math-optimizations"
11971213

@@ -1205,6 +1221,8 @@ def optimization_flags() -> str:
12051221
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
12061222
# Also, `-march=native` is unrecognized option on M1
12071223
base_flags += " -Xclang"
1224+
elif sys.platform == "win32":
1225+
pass
12081226
else:
12091227
if platform.machine() == "ppc64le":
12101228
base_flags += " -mcpu=native"
@@ -1213,12 +1231,15 @@ def optimization_flags() -> str:
12131231

12141232
# Internal cannot find libgomp.so
12151233
if not config.is_fbcode():
1216-
base_flags += " -fopenmp"
1234+
if cpp_compiler() in ["cl", "clang-cl"]:
1235+
base_flags += " /openmp"
1236+
else:
1237+
base_flags += " -fopenmp"
12171238
return base_flags
12181239

12191240

12201241
def use_custom_generated_macros() -> str:
    """Return the ``-D`` flag that defines ``C10_USING_CUSTOM_GENERATED_MACROS``.

    Returns:
        The define as a single token, with no space between ``-D`` and the
        macro name.
    """
    # Only the post-commit form is kept: the macro name is fused to ``-D``
    # so the define stays one token if the command string is later split on
    # whitespace (presumably via shlex.split by the callers -- confirm).
    return "-DC10_USING_CUSTOM_GENERATED_MACROS"
12221243

12231244

12241245
def use_fb_internal_macros() -> str:
@@ -1406,6 +1427,8 @@ def get_include_and_linking_paths(
14061427
# and raise error together with instructions at compilation error later
14071428
else:
14081429
libs = ["omp"] if config.is_fbcode() else ["gomp"]
1430+
if sys.platform == "win32" and "gomp" in libs:
1431+
libs.pop(libs.index("gomp"))
14091432

14101433
# Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
14111434
if not config.aot_inductor.abi_compatible:
@@ -1434,6 +1457,11 @@ def get_include_and_linking_paths(
14341457

14351458
lpaths_str = " ".join(["-L" + p for p in lpaths])
14361459
libs_str = " ".join(static_link_libs + ["-l" + p for p in libs])
1460+
if sys.platform == "win32":
1461+
ipaths = [p.replace(os.sep, "/") for p in ipaths]
1462+
lpaths_str = lpaths_str.replace(os.sep, "/")
1463+
libs_str = libs_str.replace(os.sep, "/")
1464+
14371465
return ipaths, lpaths_str, libs_str, macros, build_arch_flags
14381466

14391467

@@ -1454,7 +1482,7 @@ def cpp_compile_command(
14541482
)
14551483
if isinstance(input, str):
14561484
input = [input]
1457-
ipaths_str = " ".join(["-I" + p for p in ipaths])
1485+
ipaths_str = " ".join([f'-I"{p}"' for p in ipaths])
14581486
clang_flags = ""
14591487
if config.is_fbcode():
14601488
if aot_mode and not use_absolute_path:
@@ -1475,6 +1503,10 @@ def cpp_compile_command(
14751503
out_name = output
14761504
linker_paths = "" # let the compiler pick
14771505
inp_name_str = " ".join(inp_name)
1506+
1507+
out_dir = ""
1508+
if cpp_compiler() in ["cl", "clang-cl"]:
1509+
out_dir = "/Fe:" + os.path.dirname(out_name) + "/"
14781510
return re.sub(
14791511
r"[ \n]+",
14801512
" ",
@@ -1489,7 +1521,8 @@ def cpp_compile_command(
14891521
{use_fb_internal_macros()}
14901522
{use_standard_sys_dir_headers()}
14911523
{get_compile_only(compile_only)}
1492-
-o {out_name}
1524+
{out_dir}
1525+
{"-o " if "cl" not in cpp_compiler() else "/LDd /OUT:"}"{out_name}"
14931526
""",
14941527
).strip()
14951528

@@ -1600,7 +1633,7 @@ def compile(
16001633
output_so = (
16011634
config.aot_inductor.output_path
16021635
if specified_so_name
1603-
else os.path.splitext(input_path)[0] + ".so"
1636+
else os.path.splitext(input_path)[0] + (".so" if sys.platform != "win32" else ".dll")
16041637
)
16051638

16061639
if not os.path.exists(output_so):
@@ -1741,9 +1774,16 @@ def cpp_prefix() -> str:
17411774
# everything that we compile into a folder for remote compilation.
17421775
return f'#include "{os.path.basename(filename)}"'
17431776
else:
1777+
filename = filename.replace(os.sep, "/")
17441778
return f'#include "{filename}"'
17451779

17461780

1781+
@functools.lru_cache(None)
def output_encoding():
    """Return the locale's preferred text encoding.

    The value comes from ``locale.getpreferredencoding()`` and is cached
    after the first call, so later calls return the same result.
    """
    # Imported locally, as in the original, to keep module import cheap.
    from locale import getpreferredencoding

    return getpreferredencoding()
1785+
1786+
17471787
# Given a path to an input cpp file and an output path,
17481788
# Attempts to compile the file, storing the output in "output_path"
17491789
def compile_file(
@@ -1781,7 +1821,7 @@ def compile_file(
17811821
else:
17821822
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
17831823
except subprocess.CalledProcessError as e:
1784-
output = e.output.decode("utf-8")
1824+
output = e.output.decode(output_encoding())
17851825
openmp_problem = "'omp.h' file not found" in output or "libomp" in output
17861826
if openmp_problem and sys.platform == "darwin":
17871827
instruction = (
@@ -1834,7 +1874,7 @@ def load(cls, source_code: str) -> CDLL:
18341874
lock_dir = get_lock_dir()
18351875
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
18361876
with lock:
1837-
output_path = input_path[:-3] + "so"
1877+
output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll")
18381878
if not os.path.exists(output_path):
18391879
cmd = shlex.split(
18401880
cpp_compile_command(
@@ -1881,7 +1921,7 @@ def load_by_key_path(
18811921
if key not in cls.cache:
18821922
with open(path) as f:
18831923
try:
1884-
code = compile(f.read(), path, "exec")
1924+
code = compile(f.read(), path.replace(os.sep, "/"), "exec")
18851925
except Exception as e:
18861926
raise RuntimeError(
18871927
f"Failed to import {path}\n{type(e).__name__}: {e}"
@@ -1941,7 +1981,7 @@ def load(cls, source_code: str, func_name: str, key: str, cuda: bool) -> CDLL:
19411981
cpp_wrapper_dir = cpp_wrapper_cache_dir(name)
19421982
os.makedirs(cpp_wrapper_dir, exist_ok=True)
19431983

1944-
ext = "so"
1984+
ext = "so" if sys.platform != "win32" else "dll"
19451985
filepath = os.path.join(cpp_wrapper_dir, f"{name}.{ext}")
19461986
log.debug("Cpp wrapper code path %s", filepath)
19471987

torch/_inductor/codegen/cpp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3282,7 +3282,12 @@ def codegen_define_and_call(self, wrapper):
32823282
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
32833283
code.writeline(codecache.cpp_prefix())
32843284

3285-
code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
3285+
code.writeline('#ifdef _MSC_VER')
3286+
code.writeline(' #define DLLEXPORT __declspec(dllexport)')
3287+
code.writeline('#else')
3288+
code.writeline(' #define DLLEXPORT')
3289+
code.writeline('#endif')
3290+
code.writeline(f'extern "C" DLLEXPORT void {kernel_decl_name}({arg_defs})')
32863291
with code.indent():
32873292
if enable_kernel_profile:
32883293
graph_id = V.graph.graph_id

0 commit comments

Comments
 (0)