@@ -19,8 +19,11 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
19
19
20
20
public:
21
21
22
- CreateFunctionFromSubroutine (Allocator &al_) :
23
- PassVisitor (al_, nullptr )
22
+ std::map<ASR::symbol_t *, ASR::symbol_t *>& function2subroutine;
23
+
24
+ CreateFunctionFromSubroutine (Allocator &al_,
25
+ std::map<ASR::symbol_t *, ASR::symbol_t *>& function2subroutine_) :
26
+ PassVisitor (al_, nullptr ), function2subroutine(function2subroutine_)
24
27
{
25
28
pass_result.reserve (al, 1 );
26
29
}
@@ -73,6 +76,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
73
76
*/
74
77
if ( PassUtils::is_aggregate_type (s->m_return_var ) ) {
75
78
ASR::symbol_t * s_sub = create_subroutine_from_function (s);
79
+ function2subroutine[item.second ] = s_sub;
76
80
replace_vec.push_back (std::make_pair (item.first , s_sub));
77
81
}
78
82
}
@@ -86,7 +90,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
86
90
// of the function (which returned array) now points
87
91
// to the newly created subroutine.
88
92
for ( auto & item: replace_vec ) {
89
- xx.m_global_scope ->add_symbol (item.first , item.second );
93
+ xx.m_global_scope ->overwrite_symbol (item.first , item.second );
90
94
}
91
95
92
96
// Now visit everything else
@@ -96,6 +100,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
96
100
}
97
101
98
102
void visit_Module (const ASR::Module_t &x) {
103
+ std::vector<std::pair<std::string, ASR::symbol_t *> > replace_vec;
99
104
// FIXME: this is a hack, we need to pass in a non-const `x`,
100
105
// which requires to generate a TransformVisitor.
101
106
ASR::Module_t &xx = const_cast <ASR::Module_t&>(x);
@@ -112,13 +117,21 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
112
117
*/
113
118
if ( PassUtils::is_aggregate_type (s->m_return_var ) ) {
114
119
ASR::symbol_t * s_sub = create_subroutine_from_function (s);
120
+ function2subroutine[item.second ] = s_sub;
115
121
// Update the symtab with this function changes
116
- xx. m_symtab -> overwrite_symbol ( item.first , s_sub);
122
+ replace_vec. push_back ( std::make_pair ( item.first , s_sub) );
117
123
}
118
124
}
119
125
}
120
126
}
121
127
128
+ // Updating the symbol table so that now the name
129
+ // of the function (which returned array) now points
130
+ // to the newly created subroutine.
131
+ for ( auto & item: replace_vec ) {
132
+ current_scope->overwrite_symbol (item.first , item.second );
133
+ }
134
+
122
135
// Now visit everything else
123
136
for (auto &item : x.m_symtab ->get_scope ()) {
124
137
this ->visit_symbol (*item.second );
@@ -154,7 +167,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
154
167
// of the function (which returned array) now points
155
168
// to the newly created subroutine.
156
169
for ( auto & item: replace_vec ) {
157
- current_scope->add_symbol (item.first , item.second );
170
+ current_scope->overwrite_symbol (item.first , item.second );
158
171
}
159
172
160
173
for (auto &item : x.m_symtab ->get_scope ()) {
@@ -183,12 +196,39 @@ class ReplaceFunctionCallWithSubroutineCall: public PassUtils::PassVisitor<Repla
183
196
184
197
public:
185
198
186
- ReplaceFunctionCallWithSubroutineCall (Allocator& al_):
187
- PassVisitor (al_, nullptr ), result_var(nullptr )
199
+ std::map<ASR::symbol_t *, ASR::symbol_t *>& function2subroutine;
200
+
201
+ ReplaceFunctionCallWithSubroutineCall (Allocator& al_,
202
+ std::map<ASR::symbol_t *, ASR::symbol_t *>& function2subroutine_):
203
+ PassVisitor (al_, nullptr ), result_var(nullptr ),
204
+ function2subroutine (function2subroutine_)
188
205
{
189
206
pass_result.reserve (al, 1 );
190
207
}
191
208
209
+ void visit_ExternalSymbol (const ASR::ExternalSymbol_t& x) {
210
+ ASR::ExternalSymbol_t& xx = const_cast <ASR::ExternalSymbol_t&>(x);
211
+ if ( function2subroutine.find (xx.m_external ) != function2subroutine.end () ) {
212
+ xx.m_external = function2subroutine[xx.m_external ];
213
+ }
214
+ }
215
+
216
+ #define visit_ExternalSymbols (Node ) PassVisitor<ReplaceFunctionCallWithSubroutineCall>::visit_##Node(x); \
217
+ for (auto &item : x.m_symtab->get_scope ()) { \
218
+ if ( is_a<ASR::ExternalSymbol_t>(*item.second ) ) { \
219
+ ASR::ExternalSymbol_t* s = ASR::down_cast<ASR::ExternalSymbol_t>(item.second ); \
220
+ visit_ExternalSymbol (*s); \
221
+ } \
222
+ } \
223
+
224
+ void visit_Program (const ASR::Program_t &x) {
225
+ visit_ExternalSymbols (Program)
226
+ }
227
+
228
+ void visit_Module (const ASR::Module_t &x) {
229
+ visit_ExternalSymbols (Module)
230
+ }
231
+
192
232
void visit_Assignment (const ASR::Assignment_t& x) {
193
233
if ( PassUtils::is_aggregate_type (x.m_target ) ) {
194
234
result_var = x.m_target ;
@@ -239,9 +279,10 @@ class ReplaceFunctionCallWithSubroutineCall: public PassUtils::PassVisitor<Repla
239
279
240
280
void pass_create_subroutine_from_function (Allocator &al, ASR::TranslationUnit_t &unit,
241
281
const LCompilers::PassOptions& /* pass_options*/ ) {
242
- CreateFunctionFromSubroutine v (al);
282
+ std::map<ASR::symbol_t *, ASR::symbol_t *> function2subroutine;
283
+ CreateFunctionFromSubroutine v (al, function2subroutine);
243
284
v.visit_TranslationUnit (unit);
244
- ReplaceFunctionCallWithSubroutineCall u (al);
285
+ ReplaceFunctionCallWithSubroutineCall u (al, function2subroutine );
245
286
u.visit_TranslationUnit (unit);
246
287
}
247
288
0 commit comments