Skip to content

Commit c1243f4

Browse files
committed
resync
1 parent 85eadac commit c1243f4

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

src/libasr/asr_verify.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,8 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
610610
+ "') + scope_names not found in a module '"
611611
+ asr_owner_name + "'");
612612
require(s == x.m_external,
613-
"ExternalSymbol::m_name + scope_names found but not equal to m_external");
613+
std::string("ExternalSymbol::m_name + scope_names found but not equal to m_external, ") +
614+
"original_name " + std::string(x.m_original_name) + ".");
614615
}
615616
}
616617

src/libasr/pass/subroutine_from_function.cpp

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
1919

2020
public:
2121

22-
CreateFunctionFromSubroutine(Allocator &al_) :
23-
PassVisitor(al_, nullptr)
22+
std::map<ASR::symbol_t*, ASR::symbol_t*>& function2subroutine;
23+
24+
CreateFunctionFromSubroutine(Allocator &al_,
25+
std::map<ASR::symbol_t*, ASR::symbol_t*>& function2subroutine_) :
26+
PassVisitor(al_, nullptr), function2subroutine(function2subroutine_)
2427
{
2528
pass_result.reserve(al, 1);
2629
}
@@ -73,6 +76,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
7376
*/
7477
if( PassUtils::is_aggregate_type(s->m_return_var) ) {
7578
ASR::symbol_t* s_sub = create_subroutine_from_function(s);
79+
function2subroutine[item.second] = s_sub;
7680
replace_vec.push_back(std::make_pair(item.first, s_sub));
7781
}
7882
}
@@ -86,7 +90,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
8690
// of the function (which returned array) now points
8791
// to the newly created subroutine.
8892
for( auto& item: replace_vec ) {
89-
xx.m_global_scope->add_symbol(item.first, item.second);
93+
xx.m_global_scope->overwrite_symbol(item.first, item.second);
9094
}
9195

9296
// Now visit everything else
@@ -96,6 +100,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
96100
}
97101

98102
void visit_Module(const ASR::Module_t &x) {
103+
std::vector<std::pair<std::string, ASR::symbol_t*> > replace_vec;
99104
// FIXME: this is a hack, we need to pass in a non-const `x`,
100105
// which requires to generate a TransformVisitor.
101106
ASR::Module_t &xx = const_cast<ASR::Module_t&>(x);
@@ -112,13 +117,21 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
112117
*/
113118
if( PassUtils::is_aggregate_type(s->m_return_var) ) {
114119
ASR::symbol_t* s_sub = create_subroutine_from_function(s);
120+
function2subroutine[item.second] = s_sub;
115121
// Update the symtab with this function changes
116-
xx.m_symtab->overwrite_symbol(item.first, s_sub);
122+
replace_vec.push_back(std::make_pair(item.first, s_sub));
117123
}
118124
}
119125
}
120126
}
121127

128+
// Updating the symbol table so that now the name
129+
// of the function (which returned array) now points
130+
// to the newly created subroutine.
131+
for( auto& item: replace_vec ) {
132+
current_scope->overwrite_symbol(item.first, item.second);
133+
}
134+
122135
// Now visit everything else
123136
for (auto &item : x.m_symtab->get_scope()) {
124137
this->visit_symbol(*item.second);
@@ -154,7 +167,7 @@ class CreateFunctionFromSubroutine: public PassUtils::PassVisitor<CreateFunction
154167
// of the function (which returned array) now points
155168
// to the newly created subroutine.
156169
for( auto& item: replace_vec ) {
157-
current_scope->add_symbol(item.first, item.second);
170+
current_scope->overwrite_symbol(item.first, item.second);
158171
}
159172

160173
for (auto &item : x.m_symtab->get_scope()) {
@@ -183,12 +196,39 @@ class ReplaceFunctionCallWithSubroutineCall: public PassUtils::PassVisitor<Repla
183196

184197
public:
185198

186-
ReplaceFunctionCallWithSubroutineCall(Allocator& al_):
187-
PassVisitor(al_, nullptr), result_var(nullptr)
199+
std::map<ASR::symbol_t*, ASR::symbol_t*>& function2subroutine;
200+
201+
ReplaceFunctionCallWithSubroutineCall(Allocator& al_,
202+
std::map<ASR::symbol_t*, ASR::symbol_t*>& function2subroutine_):
203+
PassVisitor(al_, nullptr), result_var(nullptr),
204+
function2subroutine(function2subroutine_)
188205
{
189206
pass_result.reserve(al, 1);
190207
}
191208

209+
void visit_ExternalSymbol(const ASR::ExternalSymbol_t& x) {
210+
ASR::ExternalSymbol_t& xx = const_cast<ASR::ExternalSymbol_t&>(x);
211+
if( function2subroutine.find(xx.m_external) != function2subroutine.end() ) {
212+
xx.m_external = function2subroutine[xx.m_external];
213+
}
214+
}
215+
216+
#define visit_ExternalSymbols(Node) PassVisitor<ReplaceFunctionCallWithSubroutineCall>::visit_##Node(x); \
217+
for (auto &item : x.m_symtab->get_scope()) { \
218+
if( is_a<ASR::ExternalSymbol_t>(*item.second) ) { \
219+
ASR::ExternalSymbol_t* s = ASR::down_cast<ASR::ExternalSymbol_t>(item.second); \
220+
visit_ExternalSymbol(*s); \
221+
} \
222+
} \
223+
224+
void visit_Program(const ASR::Program_t &x) {
225+
visit_ExternalSymbols(Program)
226+
}
227+
228+
void visit_Module(const ASR::Module_t &x) {
229+
visit_ExternalSymbols(Module)
230+
}
231+
192232
void visit_Assignment(const ASR::Assignment_t& x) {
193233
if( PassUtils::is_aggregate_type(x.m_target) ) {
194234
result_var = x.m_target;
@@ -239,9 +279,10 @@ class ReplaceFunctionCallWithSubroutineCall: public PassUtils::PassVisitor<Repla
239279

240280
void pass_create_subroutine_from_function(Allocator &al, ASR::TranslationUnit_t &unit,
241281
const LCompilers::PassOptions& /*pass_options*/) {
242-
CreateFunctionFromSubroutine v(al);
282+
std::map<ASR::symbol_t*, ASR::symbol_t*> function2subroutine;
283+
CreateFunctionFromSubroutine v(al, function2subroutine);
243284
v.visit_TranslationUnit(unit);
244-
ReplaceFunctionCallWithSubroutineCall u(al);
285+
ReplaceFunctionCallWithSubroutineCall u(al, function2subroutine);
245286
u.visit_TranslationUnit(unit);
246287
}
247288

0 commit comments

Comments
 (0)