@@ -47,7 +47,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
47
47
bool fixed_duplicated_expr_stmt;
48
48
bool is_fast;
49
49
50
- // Stores the local variables corresponding to the ones
50
+ // Stores the local variables or/and block symbol corresponding to the ones
51
51
// present in function symbol table.
52
52
std::map<std::string, ASR::symbol_t *> arg2value;
53
53
@@ -95,33 +95,42 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
95
95
current_routine.clear ();
96
96
}
97
97
98
+ // If anything is not local to a function being inlined
99
+ // then do not inline the function by setting
100
+ // fixed_duplicated_expr_stmt to false.
101
+ // To be supported later.
102
+ #define replace_symbol (sym, symbol_t, m_v ) \
103
+ std::string sym_name = ASRUtils::symbol_name(sym); \
104
+ if ( current_routine_scope && \
105
+ current_routine_scope->get_symbol (sym_name) == nullptr ) { \
106
+ fixed_duplicated_expr_stmt = false ; \
107
+ return ; \
108
+ } \
109
+ LCOMPILERS_ASSERT (ASR::is_a<symbol_t >(*sym)) \
110
+ if( arg2value.find(sym_name) != arg2value.end() ) { \
111
+ symbol_t *x_var = ASR::down_cast<symbol_t >(arg2value[sym_name]); \
112
+ if ( current_scope->get_symbol (std::string (x_var->m_name ))) { \
113
+ m_v = arg2value[sym_name]; \
114
+ } \
115
+ } else { \
116
+ fixed_duplicated_expr_stmt = false ; \
117
+ }
118
+
98
119
void visit_Var (const ASR::Var_t& x) {
99
120
ASR::Var_t& xx = const_cast <ASR::Var_t&>(x);
100
- std::string x_var_name = std::string (ASRUtils::symbol_name (x.m_v ));
101
-
102
- // If anything is not local to a function being inlined
103
- // then do not inline the function by setting
104
- // fixed_duplicated_expr_stmt to false.
105
- // To be supported later.
106
- if ( current_routine_scope &&
107
- current_routine_scope->get_symbol (x_var_name) == nullptr ) {
108
- fixed_duplicated_expr_stmt = false ;
109
- return ;
110
- }
111
- if ( x.m_v ->type == ASR::symbolType::Variable ) {
112
- ASR::Variable_t* x_var = ASR::down_cast<ASR::Variable_t>(x.m_v );
113
- if ( arg2value.find (x_var_name) != arg2value.end () ) {
114
- x_var = ASR::down_cast<ASR::Variable_t>(arg2value[x_var_name]);
115
- if ( current_scope->get_symbol (std::string (x_var->m_name )) != nullptr ) {
116
- xx.m_v = arg2value[x_var_name];
117
- }
118
- x_var = ASR::down_cast<ASR::Variable_t>(x.m_v );
119
- }
121
+ ASR::symbol_t *sym = ASRUtils::symbol_get_past_external (x.m_v );
122
+ if (ASR::is_a<ASR::EnumType_t>(*sym)) {
123
+ replace_symbol (sym, ASR::EnumType_t, xx.m_v );
120
124
} else {
121
- fixed_duplicated_expr_stmt = false ;
125
+ replace_symbol (sym, ASR::Variable_t, xx. m_v ) ;
122
126
}
123
127
}
124
128
129
+ void visit_BlockCall (const ASR::BlockCall_t &x) {
130
+ ASR::BlockCall_t& xx = const_cast <ASR::BlockCall_t&>(x);
131
+ replace_symbol (x.m_m , ASR::Block_t, xx.m_m );
132
+ }
133
+
125
134
void set_empty_block (SymbolTable* scope, const Location& loc) {
126
135
std::string empty_block_name = scope->get_unique_name (" ~empty_block" );
127
136
if ( empty_block_name != " ~empty_block" ) {
@@ -133,6 +142,7 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
133
142
s2c (al, empty_block_name), nullptr , 0 ));
134
143
scope->add_symbol (empty_block_name, empty_block);
135
144
}
145
+ arg2value[empty_block_name] = empty_block;
136
146
}
137
147
138
148
void remove_empty_block (SymbolTable* scope) {
0 commit comments