Skip to content

Commit 2716adf

Browse files
committed
MSVC fixes
1 parent 4c55dc5 commit 2716adf

File tree

2 files changed

+66
-17
lines changed

2 files changed

+66
-17
lines changed

torch/_inductor/codecache.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def get_path(basename: str, extension: str, specified_dir: str = ""):
312312
subdir = os.path.join(cache_dir(), specified_dir)
313313
else:
314314
subdir = os.path.join(cache_dir(), basename[1:3])
315-
path = os.path.join(subdir, f"{basename}.{extension}")
315+
path = os.path.join(subdir, f"{basename}.{extension}").replace(os.sep, "/")
316316
return basename, subdir, path
317317

318318

@@ -431,7 +431,10 @@ def cpp_compiler_search(search):
431431
)
432432
with lock:
433433
cxx = install_gcc_via_conda()
434-
subprocess.check_output([cxx, "--version"])
434+
if cxx == "cl":
435+
subprocess.check_output([cxx])
436+
else:
437+
subprocess.check_output([cxx, "--version"])
435438
return cxx
436439
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
437440
continue
@@ -504,7 +507,12 @@ class VecISA:
504507
505508
__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};
506509
507-
extern "C" void __avx_chk_kernel() {
510+
#ifdef _MSC_VER
511+
#define DLLEXPORT __declspec(dllexport)
512+
#else
513+
#define DLLEXPORT
514+
#endif
515+
extern "C" DLLEXPORT void __avx_chk_kernel() {
508516
auto tmp0 = at::vec::Vectorized<float>(1);
509517
auto tmp1 = tmp0.exp();
510518
tmp1.store(in_out_ptr0);
@@ -543,7 +551,7 @@ def __bool__(self):
543551
lock_dir = get_lock_dir()
544552
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
545553
with lock:
546-
output_path = input_path[:-3] + "so"
554+
output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll")
547555
build_cmd = shlex.split(
548556
cpp_compile_command(
549557
input_path, output_path, warning_all=False, vec_isa=self
@@ -647,6 +655,10 @@ def pick_vec_isa():
647655

648656

649657
def get_shared(shared=True):
658+
if sys.platform == "win32":
659+
if cpp_compiler() in ["cl", "clang", "clang-cl"]:
660+
return ""
661+
return "-shared" if shared else ""
650662
return "-shared -fPIC" if shared else ""
651663

652664

@@ -655,6 +667,8 @@ def get_warning_all_flag(warning_all=True):
655667

656668

657669
def cpp_flags():
670+
if cpp_compiler() in ["cl", "clang-cl"]:
671+
return "/std:c++17"
658672
return "-std=c++17 -Wno-unused-variable"
659673

660674

@@ -664,6 +678,8 @@ def cpp_wrapper_flags():
664678

665679
def optimization_flags():
666680
base_flags = "-O3 -ffast-math -fno-finite-math-only"
681+
if cpp_compiler() in ["cl", "clang-cl"]:
682+
base_flags = "/nologo /O2 /fp:fast"
667683
if config.is_fbcode():
668684
# FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
669685
# This causes `ldopen` to fail in fbcode, because libgomp does not exist in the default paths.
@@ -674,6 +690,8 @@ def optimization_flags():
674690
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
675691
# Also, `-march=native` is unrecognized option on M1
676692
base_flags += " -Xclang"
693+
elif sys.platform == "win32":
694+
pass
677695
else:
678696
if platform.machine() == "ppc64le":
679697
base_flags += " -mcpu=native"
@@ -682,12 +700,15 @@ def optimization_flags():
682700

683701
# Internal cannot find libgomp.so
684702
if not config.is_fbcode():
685-
base_flags += " -fopenmp"
703+
if cpp_compiler() in ["cl", "clang-cl"]:
704+
base_flags += " /openmp"
705+
else:
706+
base_flags += " -fopenmp"
686707
return base_flags
687708

688709

689710
def use_custom_generated_macros():
690-
return "-D C10_USING_CUSTOM_GENERATED_MACROS"
711+
return "-DC10_USING_CUSTOM_GENERATED_MACROS"
691712

692713

693714
def use_fb_internal_macros():
@@ -844,6 +865,9 @@ def get_include_and_linking_paths(
844865
else:
845866
libs = ["omp"] if config.is_fbcode() else ["gomp"]
846867

868+
if sys.platform == "win32" and "gomp" in libs:
869+
libs.pop(libs.index("gomp"))
870+
847871
# third party libs
848872
if config.is_fbcode():
849873
ipaths.append(build_paths.sleef())
@@ -859,9 +883,13 @@ def get_include_and_linking_paths(
859883
# (later on, we copy the include paths from cpp_extensions into our remote dir)
860884
ipaths.append("include")
861885

862-
ipaths = " ".join(["-I" + p for p in ipaths])
863-
lpaths = " ".join(["-L" + p for p in lpaths])
864-
libs = " ".join(["-l" + p for p in libs])
886+
ipaths = " ".join([f'-I"{p}"' for p in ipaths])
887+
lpaths = " ".join([f'-L"{p}"' for p in lpaths])
888+
libs = " ".join([f'-l"{p}"' for p in libs])
889+
if sys.platform == "win32":
890+
ipaths = ipaths.replace(os.sep, "/")
891+
lpaths = lpaths.replace(os.sep, "/")
892+
libs = libs.replace(os.sep, "/")
865893
return ipaths, lpaths, libs, macros
866894

867895

@@ -892,6 +920,10 @@ def cpp_compile_command(
892920
inp_name = input
893921
out_name = output
894922
linker_paths = "" # let the compiler pick
923+
924+
out_dir = ""
925+
if cpp_compiler() in ["cl", "clang-cl"]:
926+
out_dir = "/Fe:" + os.path.dirname(out_name) + "/"
895927
return re.sub(
896928
r"[ \n]+",
897929
" ",
@@ -903,7 +935,8 @@ def cpp_compile_command(
903935
{use_custom_generated_macros()}
904936
{use_fb_internal_macros()}
905937
{use_standard_sys_dir_headers()}
906-
-o {out_name}
938+
{out_dir}
939+
{"-o " if "cl" not in cpp_compiler() else "/LDd /OUT:"}"{out_name}"
907940
""",
908941
).strip()
909942

@@ -953,7 +986,7 @@ def compile(cls, graph, source_code, cuda):
953986
lock_dir = get_lock_dir()
954987
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
955988
with lock:
956-
output_so = os.path.splitext(input_path)[0] + ".so"
989+
output_so = os.path.splitext(input_path)[0] + (".so" if sys.platform != "win32" else ".dll")
957990

958991
if not os.path.exists(output_so):
959992
cmd = shlex.split(
@@ -1011,9 +1044,16 @@ def cpp_prefix():
10111044
# everything that we compile into a folder for remote compilation.
10121045
return f'#include "{os.path.basename(filename)}"'
10131046
else:
1047+
filename = filename.replace(os.sep, "/")
10141048
return f'#include "{filename}"'
10151049

10161050

1051+
@functools.lru_cache(None)
1052+
def output_encoding():
1053+
import locale
1054+
return locale.getpreferredencoding()
1055+
1056+
10171057
# Given a path to an input cpp file and an output path,
10181058
# Attempts to compile the file, storing the output in "output_path"
10191059
def compile_file(input_path, output_path, cmd) -> None:
@@ -1045,7 +1085,7 @@ def compile_file(input_path, output_path, cmd) -> None:
10451085
else:
10461086
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
10471087
except subprocess.CalledProcessError as e:
1048-
output = e.output.decode("utf-8")
1088+
output = e.output.decode(output_encoding())
10491089
openmp_problem = "'omp.h' file not found" in output or "libomp" in output
10501090
if openmp_problem and sys.platform == "darwin":
10511091
instruction = (
@@ -1095,15 +1135,19 @@ def load(cls, source_code):
10951135
lock_dir = get_lock_dir()
10961136
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
10971137
with lock:
1098-
output_path = input_path[:-3] + "so"
1138+
output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll")
10991139
if not os.path.exists(output_path):
11001140
cmd = shlex.split(
11011141
cpp_compile_command(
11021142
input=input_path, output=output_path, vec_isa=picked_vec_isa
11031143
)
11041144
)
11051145
compile_file(input_path, output_path, cmd)
1106-
cls.cache[key] = cls._load_library(output_path)
1146+
if sys.platform == "win32":
1147+
#cls.cache[key] = cls._load_library(os.path.join(".", os.path.basename(output_path)))
1148+
cls.cache[key] = cls._load_library(output_path)
1149+
else:
1150+
cls.cache[key] = cls._load_library(output_path)
11071151
cls.cache[key].key = key
11081152

11091153
return cls.cache[key]
@@ -1128,7 +1172,7 @@ def load_by_key_path(cls, key, path, linemap=()):
11281172
if key not in cls.cache:
11291173
with open(path) as f:
11301174
try:
1131-
code = compile(f.read(), path, "exec")
1175+
code = compile(f.read(), path.replace(os.sep, "/"), "exec")
11321176
except Exception as e:
11331177
raise RuntimeError(
11341178
f"Failed to import {path}\n{type(e).__name__}: {e}"
@@ -1183,7 +1227,7 @@ def load(cls, source_code, func_name, key, cuda):
11831227
if not os.path.exists(cpp_wrapper_dir):
11841228
os.makedirs(cpp_wrapper_dir)
11851229

1186-
ext = "so"
1230+
ext = "so" if sys.platform != "win32" else "dll"
11871231
filepath = os.path.join(cpp_wrapper_dir, f"{name}.{ext}")
11881232
log.debug("Cpp wrapper code path %s", filepath)
11891233

torch/_inductor/codegen/cpp.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2858,7 +2858,12 @@ def codegen_define_and_call(self, wrapper):
28582858
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
28592859
code.writeline(codecache.cpp_prefix())
28602860

2861-
code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
2861+
code.writeline('#ifdef _MSC_VER')
2862+
code.writeline(' #define DLLEXPORT __declspec(dllexport)')
2863+
code.writeline('#else')
2864+
code.writeline(' #define DLLEXPORT')
2865+
code.writeline('#endif')
2866+
code.writeline(f'extern "C" DLLEXPORT void {kernel_decl_name}({arg_defs})')
28622867
with code.indent():
28632868
if enable_kernel_profile:
28642869
graph_id = V.graph.graph_id

0 commit comments

Comments
 (0)