diff --git a/src/libasr/pass/inline_function_calls.cpp b/src/libasr/pass/inline_function_calls.cpp index 33336dadce..a013f3894d 100644 --- a/src/libasr/pass/inline_function_calls.cpp +++ b/src/libasr/pass/inline_function_calls.cpp @@ -47,7 +47,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor arg2value; @@ -95,33 +95,40 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitorget_symbol(sym_name) == nullptr ) { \ + fixed_duplicated_expr_stmt = false; \ + return ; \ + } \ + if( arg2value.find(sym_name) != arg2value.end() ) { \ + LCOMPILERS_ASSERT(ASR::is_a(*sym)) \ + symbol_t *x_var = ASR::down_cast(arg2value[sym_name]); \ + if( current_scope->get_symbol(std::string(x_var->m_name))) { \ + m_v = arg2value[sym_name]; \ + } \ + } + void visit_Var(const ASR::Var_t& x) { ASR::Var_t& xx = const_cast(x); - std::string x_var_name = std::string(ASRUtils::symbol_name(x.m_v)); - - // If anything is not local to a function being inlined - // then do not inline the function by setting - // fixed_duplicated_expr_stmt to false. - // To be supported later. - if( current_routine_scope && - current_routine_scope->get_symbol(x_var_name) == nullptr ) { - fixed_duplicated_expr_stmt = false; - return ; - } - if( x.m_v->type == ASR::symbolType::Variable ) { - ASR::Variable_t* x_var = ASR::down_cast(x.m_v); - if( arg2value.find(x_var_name) != arg2value.end() ) { - x_var = ASR::down_cast(arg2value[x_var_name]); - if( current_scope->get_symbol(std::string(x_var->m_name)) != nullptr ) { - xx.m_v = arg2value[x_var_name]; - } - x_var = ASR::down_cast(x.m_v); - } + ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(x.m_v); + if (ASR::is_a(*sym)) { + replace_symbol(sym, ASR::Variable_t, xx.m_v); } else { fixed_duplicated_expr_stmt = false; } } + void visit_BlockCall(const ASR::BlockCall_t &x) { + ASR::BlockCall_t& xx = const_cast(x); + replace_symbol(x.m_m, ASR::Block_t, xx.m_m); + } + void set_empty_block(SymbolTable* scope, const Location& loc) { std::string empty_block_name = scope->get_unique_name("~empty_block"); if( empty_block_name != "~empty_block" ) { @@ -133,6 +140,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitoradd_symbol(empty_block_name, empty_block); } + arg2value[empty_block_name] = empty_block; } void remove_empty_block(SymbolTable* scope) {