@@ -24,24 +24,30 @@ def append_xpu_function_header(src, dst):
24
24
if args .dry_run :
25
25
return
26
26
27
- # Remove trailing empty lines from destination
28
- with open (dst , "r+" , encoding = "utf-8" ) as f :
29
- lines = f .readlines ()
30
- # Remove trailing empty lines
31
- while lines and not lines [- 1 ].strip ():
32
- lines .pop ()
33
- f .seek (0 )
34
- f .truncate ()
35
- f .writelines (lines )
36
-
37
- # Read source file and append matching lines
27
+ # Read source file and match header lines
38
28
with open (src , encoding = "utf-8" ) as fr :
39
29
src_text = fr .read ()
40
30
pattern = r"^#include <ATen/ops/.*>\s*\r?\n"
41
31
matches = re .findall (pattern , src_text , re .MULTILINE )
42
- if matches :
43
- with open (dst , "a" , encoding = "utf-8" ) as fa :
44
- fa .writelines (matches )
32
+ if not matches :
33
+ return
34
+
35
+ with open (dst , "r+" , encoding = "utf-8" ) as f :
36
+ dst_lines = f .readlines ()
37
+ dst_text = "" .join (dst_lines )
38
+ missing_headers = [match for match in matches if match not in dst_text ]
39
+ if not missing_headers :
40
+ return
41
+
42
+ # Remove trailing empty lines from dst_lines
43
+ while dst_lines and not dst_lines [- 1 ].strip ():
44
+ dst_lines .pop ()
45
+
46
+ f .seek (0 )
47
+ f .truncate ()
48
+ f .writelines (dst_lines )
49
+ # Append missing headers to the end of the file
50
+ f .writelines (missing_headers )
45
51
46
52
47
53
def parse_ops_headers (src ):
0 commit comments