From 7cd21b121d0421c3e45e16b03468713dcbaf6539 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 23 Dec 2023 14:31:57 +0900 Subject: [PATCH 1/3] MSVC fixes --- torch/_inductor/codecache.py | 66 +++++++++++++++++++++++++++------- torch/_inductor/codegen/cpp.py | 7 +++- 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 8b66a5328ce74c..011e7a646dc0f3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -342,7 +342,7 @@ def get_path( subdir = os.path.join(cache_dir(), specified_dir) else: subdir = os.path.join(cache_dir(), basename[1:3]) - path = os.path.join(subdir, f"{basename}.{extension}") + path = os.path.join(subdir, f"{basename}.{extension}").replace(os.sep, "/") return basename, subdir, path @@ -921,7 +921,10 @@ def cpp_compiler_search(search: str) -> str: ) with lock: cxx = install_gcc_via_conda() - subprocess.check_output([cxx, "--version"]) + if cxx == "cl": + subprocess.check_output([cxx]) + else: + subprocess.check_output([cxx, "--version"]) return cxx except (subprocess.SubprocessError, FileNotFoundError, ImportError): continue @@ -998,7 +1001,12 @@ class VecISA: __attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; -extern "C" void __avx_chk_kernel() { +#ifdef _MSC_VER +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT +#endif +extern "C" DLLEXPORT void __avx_chk_kernel() { auto tmp0 = at::vec::Vectorized(1); auto tmp1 = tmp0.exp(); tmp1.store(in_out_ptr0); @@ -1040,7 +1048,7 @@ def __bool__(self) -> bool: lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: - output_path = input_path[:-3] + "so" + output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll") build_cmd = shlex.split( cpp_compile_command( input_path, output_path, warning_all=False, vec_isa=self @@ -1167,6 +1175,10 @@ def get_compile_only(compile_only: bool = True) -> str: def get_shared(shared: bool = True) -> str: + if sys.platform == "win32": + if cpp_compiler() in ["cl", "clang", "clang-cl"]: + return "" + return "-shared" if shared else "" return "-shared -fPIC" if shared else "" @@ -1180,6 +1192,8 @@ def get_glibcxx_abi_build_flags() -> str: def cpp_flags() -> str: flags = ["-std=c++17", "-Wno-unused-variable", "-Wno-unknown-pragmas"] + if cpp_compiler() in ["cl", "clang-cl"]: + return "/std:c++17" if is_clang(): flags.append("-Werror=ignored-optimization-argument") return " ".join(flags) @@ -1192,6 +1206,8 @@ def cpp_wrapper_flags() -> str: def optimization_flags() -> str: base_flags = "-O0 -g" if config.aot_inductor.debug_compile else "-O3 -DNDEBUG" base_flags += " -ffast-math -fno-finite-math-only" + if cpp_compiler() in ["cl", "clang-cl"]: + base_flags = "/nologo /O2 /fp:fast" if not config.cpp.enable_unsafe_math_opt_flag: base_flags += " -fno-unsafe-math-optimizations" @@ -1205,6 +1221,8 @@ def optimization_flags() -> str: # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` # Also, `-march=native` is unrecognized option on M1 base_flags += " -Xclang" + elif sys.platform == "win32": + pass else: if platform.machine() == "ppc64le": base_flags += " -mcpu=native" @@ -1213,12 +1231,15 @@ def optimization_flags() -> str: # Internal cannot find libgomp.so if not config.is_fbcode(): - base_flags += " -fopenmp" + if cpp_compiler() in ["cl", "clang-cl"]: + base_flags += " /openmp" + else: + base_flags += " -fopenmp" return base_flags def use_custom_generated_macros() -> str: - return "-D C10_USING_CUSTOM_GENERATED_MACROS" + return "-DC10_USING_CUSTOM_GENERATED_MACROS" def use_fb_internal_macros() -> str: @@ -1406,6 +1427,8 @@ def get_include_and_linking_paths( # and raise error together with instructions at compilation error later else: libs = ["omp"] if config.is_fbcode() else ["gomp"] + if sys.platform == "win32" and "gomp" in libs: + libs.pop(libs.index("gomp")) # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 if not config.aot_inductor.abi_compatible: @@ -1434,6 +1457,11 @@ def get_include_and_linking_paths( lpaths_str = " ".join(["-L" + p for p in lpaths]) libs_str = " ".join(static_link_libs + ["-l" + p for p in libs]) + if sys.platform == "win32": + ipaths = [p.replace(os.sep, "/") for p in ipaths] + lpaths_str = lpaths_str.replace(os.sep, "/") + libs_str = libs_str.replace(os.sep, "/") + return ipaths, lpaths_str, libs_str, macros, build_arch_flags @@ -1454,7 +1482,7 @@ def cpp_compile_command( ) if isinstance(input, str): input = [input] - ipaths_str = " ".join(["-I" + p for p in ipaths]) + ipaths_str = " ".join([f'-I"{p}"' for p in ipaths]) clang_flags = "" if config.is_fbcode(): if aot_mode and not use_absolute_path: @@ -1475,6 +1503,10 @@ def cpp_compile_command( out_name = output linker_paths = "" # let the compiler pick inp_name_str = " ".join(inp_name) + + out_dir = "" + if cpp_compiler() in ["cl", "clang-cl"]: + out_dir = "/Fe:" + os.path.dirname(out_name) + "/" return re.sub( r"[ \n]+", " ", @@ -1489,7 +1521,8 @@ def cpp_compile_command( {use_fb_internal_macros()} {use_standard_sys_dir_headers()} {get_compile_only(compile_only)} - -o {out_name} + {out_dir} + {"-o " if "cl" not in cpp_compiler() else "/LDd /OUT:"}"{out_name}" """, ).strip() @@ -1600,7 +1633,7 @@ def compile( output_so = ( config.aot_inductor.output_path if specified_so_name - else os.path.splitext(input_path)[0] + ".so" + else os.path.splitext(input_path)[0] + (".so" if sys.platform != "win32" else ".dll") ) if not os.path.exists(output_so): @@ -1741,9 +1774,16 @@ def cpp_prefix() -> str: # everything that we compile into a folder for remote compilation. return f'#include "{os.path.basename(filename)}"' else: + filename = filename.replace(os.sep, "/") return f'#include "{filename}"' +@functools.lru_cache(None) +def output_encoding(): + import locale + return locale.getpreferredencoding() + + # Given a path to an input cpp file and an output path, # Attempts to compile the file, storing the output in "output_path" def compile_file( @@ -1781,7 +1821,7 @@ def compile_file( else: subprocess.check_output(cmd, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: - output = e.output.decode("utf-8") + output = e.output.decode(output_encoding()) openmp_problem = "'omp.h' file not found" in output or "libomp" in output if openmp_problem and sys.platform == "darwin": instruction = ( @@ -1834,7 +1874,7 @@ def load(cls, source_code: str) -> CDLL: lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: - output_path = input_path[:-3] + "so" + output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll") if not os.path.exists(output_path): cmd = shlex.split( cpp_compile_command( @@ -1881,7 +1921,7 @@ def load_by_key_path( if key not in cls.cache: with open(path) as f: try: - code = compile(f.read(), path, "exec") + code = compile(f.read(), path.replace(os.sep, "/"), "exec") except Exception as e: raise RuntimeError( f"Failed to import {path}\n{type(e).__name__}: {e}" @@ -1941,7 +1981,7 @@ def load(cls, source_code: str, func_name: str, key: str, cuda: bool) -> CDLL: cpp_wrapper_dir = cpp_wrapper_cache_dir(name) os.makedirs(cpp_wrapper_dir, exist_ok=True) - ext = "so" + ext = "so" if sys.platform != "win32" else "dll" filepath = os.path.join(cpp_wrapper_dir, f"{name}.{ext}") log.debug("Cpp wrapper code path %s", filepath) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 9119d1bf5fb6a0..04f09be31022c9 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3282,7 +3282,12 @@ def codegen_define_and_call(self, wrapper): kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" code.writeline(codecache.cpp_prefix()) - code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})') + code.writeline('#ifdef _MSC_VER') + code.writeline(' #define DLLEXPORT __declspec(dllexport)') + code.writeline('#else') + code.writeline(' #define DLLEXPORT') + code.writeline('#endif') + code.writeline(f'extern "C" DLLEXPORT void {kernel_decl_name}({arg_defs})') with code.indent(): if enable_kernel_profile: graph_id = V.graph.graph_id From ee55add3e20679e603557f9bf2c8c92e462b627a Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 13 Jan 2024 14:57:06 +0900 Subject: [PATCH 2/3] add "cl" for win32 --- torch/_inductor/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index b61e53f09cea05..76b799b5dbef67 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -426,6 +426,9 @@ class cpp: os.environ.get("CXX", "g++"), # "g++.par", ) + if os.name == "nt": + cxx += ("cl",) + # Allow kernel performance profiling via PyTorch profiler enable_kernel_profile = False From 26c25aacb6831bf2787d9a83eaf2c090d22d7d0c Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 14 Jan 2024 16:02:00 +0900 Subject: [PATCH 3/3] use os.replace() to overwrite the existing file --- torch/_inductor/codecache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 011e7a646dc0f3..e1a44e113b0cba 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -384,7 +384,7 @@ def write_atomic(path: str, content: Union[str, bytes]) -> None: write_mode = "w" if isinstance(content, str) else "wb" with tmp_path.open(write_mode) as f: f.write(content) - tmp_path.rename(path) + tmp_path.replace(path) @dataclasses.dataclass