@@ -54,3 +54,79 @@ def failed():
54
54
assert (
55
55
"must implement TransformOpInterface to be used as transform root" in str (e )
56
56
)
57
+
58
+
59
+ print_root_via_include_module = """
60
+ module @print_root_via_include_module attributes {transform.with_named_sequence} {
61
+ transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
62
+ transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
63
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
64
+ transform.include @callee2 failures(propagate)
65
+ (%root) : (!transform.any_op) -> ()
66
+ transform.yield
67
+ }
68
+ }"""
69
+
70
+ callee2_definition = """
71
+ module attributes {transform.with_named_sequence} {
72
+ transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
73
+ transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
74
+ transform.include @callee1 failures(propagate)
75
+ (%root) : (!transform.any_op) -> ()
76
+ transform.yield
77
+ }
78
+ }
79
+ """
80
+
81
+ callee1_definition = """
82
+ module attributes {transform.with_named_sequence} {
83
+ transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
84
+ transform.print %root { name = \" from interpreter\" }: !transform.any_op
85
+ transform.yield
86
+ }
87
+ }
88
+ """
89
+
90
+
91
+ @test_in_context
92
+ def include ():
93
+ main = ir .Module .parse (print_root_via_include_module )
94
+ callee1 = ir .Module .parse (callee1_definition )
95
+ callee2 = ir .Module .parse (callee2_definition )
96
+ interp .copy_symbols_and_merge_into (main , callee1 )
97
+ interp .copy_symbols_and_merge_into (main , callee2 )
98
+
99
+ # CHECK: @print_root_via_include_module
100
+ # CHECK: transform.named_sequence @__transform_main
101
+ # CHECK: transform.include @callee2
102
+ #
103
+ # CHECK: transform.named_sequence @callee1
104
+ # CHECK: transform.print
105
+ #
106
+ # CHECK: transform.named_sequence @callee2
107
+ # CHECK: transform.include @callee1
108
+ interp .apply_named_sequence (main , main .body .operations [0 ], main )
109
+
110
+
111
+ @test_in_context
112
+ def partial_include ():
113
+ main = ir .Module .parse (print_root_via_include_module )
114
+ callee2 = ir .Module .parse (callee2_definition )
115
+ interp .copy_symbols_and_merge_into (main , callee2 )
116
+
117
+ try :
118
+ interp .apply_named_sequence (main , main .body .operations [0 ], main )
119
+ except ValueError as e :
120
+ assert "Failed to apply" in str (e )
121
+
122
+
123
+ @test_in_context
124
+ def repeated_include ():
125
+ main = ir .Module .parse (print_root_via_include_module )
126
+ callee2 = ir .Module .parse (callee2_definition )
127
+ interp .copy_symbols_and_merge_into (main , callee2 )
128
+
129
+ try :
130
+ interp .copy_symbols_and_merge_into (main , callee2 )
131
+ except ValueError as e :
132
+ assert "doubly defined symbol @callee2" in str (e )
0 commit comments