Skip to content

Commit b04c7b7

Browse files
committed
Only append if header miss
1 parent 0e28e06 commit b04c7b7

File tree

2 files changed

+32
-26
lines changed

2 files changed

+32
-26
lines changed

cmake/Codegen.cmake

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ else()
2929
endif()
3030

3131
set(XPU_CODEGEN_COMMAND
32-
"${Python_EXECUTABLE}" -m torchgen.gen
33-
--source-path ${CODEGEN_XPU_YAML_DIR}
34-
--install-dir ${BUILD_TORCH_XPU_ATEN_GENERATED}
35-
--per-operator-headers
36-
--backend-whitelist XPU SparseXPU SparseCsrXPU NestedTensorXPU
37-
--xpu
38-
)
32+
"${Python_EXECUTABLE}" -m torchgen.gen
33+
--source-path ${CODEGEN_XPU_YAML_DIR}
34+
--install-dir ${BUILD_TORCH_XPU_ATEN_GENERATED}
35+
--per-operator-headers
36+
--backend-whitelist XPU SparseXPU SparseCsrXPU NestedTensorXPU
37+
--xpu
38+
)
3939

4040
set(XPU_INSTALL_HEADER_COMMAND
41-
"${Python_EXECUTABLE}" ${TORCH_XPU_OPS_ROOT}/tools/codegen/install_xpu_headers.py
42-
--src-header-dir ${BUILD_TORCH_XPU_ATEN_GENERATED}
43-
--dst-header-dir ${BUILD_TORCH_ATEN_GENERATED}
44-
)
41+
"${Python_EXECUTABLE}" ${TORCH_XPU_OPS_ROOT}/tools/codegen/install_xpu_headers.py
42+
--src-header-dir ${BUILD_TORCH_XPU_ATEN_GENERATED}
43+
--dst-header-dir ${BUILD_TORCH_ATEN_GENERATED}
44+
)
4545

4646
# Generate ops_generated_headers.cmake for torch-xpu-ops
4747
execute_process(
@@ -121,7 +121,7 @@ add_custom_command(
121121
WORKING_DIRECTORY ${TORCH_ROOT}
122122
)
123123

124-
# Codegen post progress
124+
# Codegen post progress
125125
if(WIN32)
126126
add_custom_target(DELETE_TEMPLATES ALL DEPENDS ${OUTPUT_LIST})
127127
# Delete the copied templates folder only on Windows.

tools/codegen/install_xpu_headers.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,30 @@ def append_xpu_function_header(src, dst):
2424
if args.dry_run:
2525
return
2626

27-
# Remove trailing empty lines from destination
28-
with open(dst, "r+", encoding="utf-8") as f:
29-
lines = f.readlines()
30-
# Remove trailing empty lines
31-
while lines and not lines[-1].strip():
32-
lines.pop()
33-
f.seek(0)
34-
f.truncate()
35-
f.writelines(lines)
36-
37-
# Read source file and append matching lines
27+
# Read source file and match header lines
3828
with open(src, encoding="utf-8") as fr:
3929
src_text = fr.read()
4030
pattern = r"^#include <ATen/ops/.*>\s*\r?\n"
4131
matches = re.findall(pattern, src_text, re.MULTILINE)
42-
if matches:
43-
with open(dst, "a", encoding="utf-8") as fa:
44-
fa.writelines(matches)
32+
if not matches:
33+
return
34+
35+
with open(dst, "r+", encoding="utf-8") as f:
36+
dst_lines = f.readlines()
37+
dst_text = "".join(dst_lines)
38+
missing_headers = [match for match in matches if match not in dst_text]
39+
if not missing_headers:
40+
return
41+
42+
# Remove trailing empty lines from dst_lines
43+
while dst_lines and not dst_lines[-1].strip():
44+
dst_lines.pop()
45+
46+
f.seek(0)
47+
f.truncate()
48+
f.writelines(dst_lines)
49+
# Append missing headers to the end of the file
50+
f.writelines(missing_headers)
4551

4652

4753
def parse_ops_headers(src):

0 commit comments

Comments
 (0)