|
3 | 3 | #include <libasr/exception.h>
|
4 | 4 | #include <libasr/asr_utils.h>
|
5 | 5 | #include <libasr/asr_verify.h>
|
6 |
| -#include <libasr/pass/pass_utils.h> |
7 | 6 | #include <libasr/pass/class_constructor.h>
|
| 7 | +#include <libasr/pass/pass_utils.h> |
8 | 8 |
|
9 |
| -#include <cstring> |
10 |
| - |
| 9 | +#include <vector> |
| 10 | +#include <utility> |
| 11 | +#include <queue> |
11 | 12 |
|
12 | 13 | namespace LCompilers {
|
13 | 14 |
|
14 | 15 | using ASR::down_cast;
|
15 | 16 | using ASR::is_a;
|
16 | 17 |
|
17 |
| -class ClassConstructorVisitor : public PassUtils::PassVisitor<ClassConstructorVisitor> |
18 |
| -{ |
19 |
| -private: |
| 18 | +class ReplaceStructTypeConstructor: public ASR::BaseExprReplacer<ReplaceStructTypeConstructor> { |
| 19 | + |
| 20 | + private: |
| 21 | + |
| 22 | + Allocator& al; |
| 23 | + Vec<ASR::stmt_t*>& pass_result; |
| 24 | + bool& remove_original_statement; |
| 25 | + bool& inside_symtab; |
| 26 | + std::map<SymbolTable*, Vec<ASR::stmt_t*>>& symtab2decls; |
| 27 | + |
| 28 | + public: |
| 29 | + |
| 30 | + SymbolTable* current_scope; |
20 | 31 | ASR::expr_t* result_var;
|
21 | 32 |
|
22 |
| -public: |
| 33 | + ReplaceStructTypeConstructor(Allocator& al_, Vec<ASR::stmt_t*>& pass_result_, |
| 34 | + bool& remove_original_statement_, bool& inside_symtab_, |
| 35 | + std::map<SymbolTable*, Vec<ASR::stmt_t*>>& symtab2decls_) : |
| 36 | + al(al_), pass_result(pass_result_), |
| 37 | + remove_original_statement(remove_original_statement_), |
| 38 | + inside_symtab(inside_symtab_), symtab2decls(symtab2decls_), |
| 39 | + current_scope(nullptr), result_var(nullptr) {} |
23 | 40 |
|
24 |
| - bool is_constructor_present; |
| 41 | + void replace_StructTypeConstructor(ASR::StructTypeConstructor_t* x) { |
| 42 | + if( x->n_args == 0 ) { |
| 43 | + remove_original_statement = true; |
| 44 | + return ; |
| 45 | + } |
| 46 | + if( result_var == nullptr ) { |
| 47 | + std::string result_var_name = current_scope->get_unique_name("temp_struct_var__"); |
| 48 | + result_var = PassUtils::create_auxiliary_variable(x->base.base.loc, |
| 49 | + result_var_name, al, current_scope, x->m_type); |
| 50 | + *current_expr = result_var; |
| 51 | + } else { |
| 52 | + if( inside_symtab ) { |
| 53 | + *current_expr = nullptr; |
| 54 | + } else { |
| 55 | + remove_original_statement = true; |
| 56 | + } |
| 57 | + } |
25 | 58 |
|
26 |
| - ClassConstructorVisitor(Allocator &al) : PassVisitor(al, nullptr), |
27 |
| - result_var(nullptr), is_constructor_present(false) { |
28 |
| - pass_result.reserve(al, 0); |
29 |
| - } |
| 59 | + std::deque<ASR::symbol_t*> constructor_arg_syms; |
| 60 | + ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(x->m_type); |
| 61 | + ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>( |
| 62 | + ASRUtils::symbol_get_past_external(dt_der->m_derived_type)); |
| 63 | + while( dt_dertype ) { |
| 64 | + for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) { |
| 65 | + constructor_arg_syms.push_front( |
| 66 | + dt_dertype->m_symtab->get_symbol( |
| 67 | + dt_dertype->m_members[i])); |
| 68 | + } |
| 69 | + if( dt_dertype->m_parent != nullptr ) { |
| 70 | + ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external( |
| 71 | + dt_dertype->m_parent); |
| 72 | + LCOMPILERS_ASSERT(ASR::is_a<ASR::StructType_t>(*dt_der_sym)); |
| 73 | + dt_dertype = ASR::down_cast<ASR::StructType_t>(dt_der_sym); |
| 74 | + } else { |
| 75 | + dt_dertype = nullptr; |
| 76 | + } |
| 77 | + } |
| 78 | + LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args); |
30 | 79 |
|
31 |
| - void visit_Assignment(const ASR::Assignment_t& x) { |
32 |
| - ASR::Assignment_t& xx = const_cast<ASR::Assignment_t&>(x); |
33 |
| - if( x.m_value->type == ASR::exprType::StructTypeConstructor ) { |
34 |
| - is_constructor_present = true; |
35 |
| - if( x.m_overloaded == nullptr ) { |
36 |
| - result_var = x.m_target; |
37 |
| - visit_expr(*x.m_value); |
| 80 | + for( size_t i = 0; i < x->n_args; i++ ) { |
| 81 | + if( x->m_args[i].m_value == nullptr ) { |
| 82 | + continue ; |
| 83 | + } |
| 84 | + ASR::symbol_t* member = constructor_arg_syms[i]; |
| 85 | + if( ASR::is_a<ASR::StructTypeConstructor_t>(*x->m_args[i].m_value) ) { |
| 86 | + ASR::expr_t* result_var_copy = result_var; |
| 87 | + ASR::symbol_t *v = nullptr; |
| 88 | + if (ASR::is_a<ASR::Var_t>(*result_var_copy)) { |
| 89 | + v = ASR::down_cast<ASR::Var_t>(result_var_copy)->m_v; |
| 90 | + } |
| 91 | + result_var = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al, |
| 92 | + x->base.base.loc, (ASR::asr_t*) result_var_copy, v, |
| 93 | + member, current_scope)); |
| 94 | + ASR::expr_t** current_expr_copy = current_expr; |
| 95 | + current_expr = &(x->m_args[i].m_value); |
| 96 | + replace_expr(x->m_args[i].m_value); |
| 97 | + current_expr = current_expr_copy; |
| 98 | + result_var = result_var_copy; |
38 | 99 | } else {
|
39 |
| - std::string result_var_name = current_scope->get_unique_name("temp_struct_var__"); |
40 |
| - result_var = PassUtils::create_auxiliary_variable(x.m_value->base.loc, result_var_name, |
41 |
| - al, current_scope, ASRUtils::expr_type(x.m_target)); |
42 |
| - visit_expr(*x.m_value); |
43 |
| - ASR::stmt_t* x_m_overloaded = x.m_overloaded; |
44 |
| - if( ASR::is_a<ASR::SubroutineCall_t>(*x.m_overloaded) ) { |
45 |
| - ASR::SubroutineCall_t* assign_call = ASR::down_cast<ASR::SubroutineCall_t>(xx.m_overloaded); |
46 |
| - Vec<ASR::call_arg_t> assign_call_args; |
47 |
| - assign_call_args.reserve(al, 2); |
48 |
| - assign_call_args.push_back(al, assign_call->m_args[0]); |
49 |
| - ASR::call_arg_t arg_1; |
50 |
| - arg_1.loc = assign_call->m_args[1].loc; |
51 |
| - arg_1.m_value = result_var; |
52 |
| - assign_call_args.push_back(al, arg_1); |
53 |
| - x_m_overloaded = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x_m_overloaded->base.loc, |
54 |
| - assign_call->m_name, assign_call->m_original_name, assign_call_args.p, |
55 |
| - assign_call_args.size(), assign_call->m_dt)); |
| 100 | + Vec<ASR::stmt_t*>* result_vec = nullptr; |
| 101 | + if( inside_symtab ) { |
| 102 | + if( symtab2decls.find(current_scope) == symtab2decls.end() ) { |
| 103 | + Vec<ASR::stmt_t*> result_vec_; |
| 104 | + result_vec_.reserve(al, 0); |
| 105 | + symtab2decls[current_scope] = result_vec_; |
| 106 | + } |
| 107 | + result_vec = &symtab2decls[current_scope]; |
| 108 | + } else { |
| 109 | + result_vec = &pass_result; |
56 | 110 | }
|
57 |
| - pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, |
58 |
| - result_var, x_m_overloaded))); |
| 111 | + ASR::symbol_t *v = nullptr; |
| 112 | + if (ASR::is_a<ASR::Var_t>(*result_var)) { |
| 113 | + v = ASR::down_cast<ASR::Var_t>(result_var)->m_v; |
| 114 | + } |
| 115 | + ASR::expr_t* derived_ref = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al, |
| 116 | + x->base.base.loc, (ASR::asr_t*) result_var, v, |
| 117 | + member, current_scope)); |
| 118 | + ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, |
| 119 | + x->base.base.loc, derived_ref, |
| 120 | + x->m_args[i].m_value, nullptr)); |
| 121 | + result_vec->push_back(al, assign); |
59 | 122 | }
|
60 | 123 | }
|
61 |
| - |
62 | 124 | }
|
| 125 | +}; |
63 | 126 |
|
64 |
| - void visit_StructTypeConstructor(const ASR::StructTypeConstructor_t &x) { |
65 |
| - if( x.n_args == 0 ) { |
66 |
| - remove_original_stmt = true; |
| 127 | +class StructTypeConstructorVisitor : public ASR::CallReplacerOnExpressionsVisitor<StructTypeConstructorVisitor> |
| 128 | +{ |
| 129 | + private: |
| 130 | + |
| 131 | + Allocator& al; |
| 132 | + bool remove_original_statement; |
| 133 | + bool inside_symtab; |
| 134 | + ReplaceStructTypeConstructor replacer; |
| 135 | + Vec<ASR::stmt_t*> pass_result; |
| 136 | + std::map<SymbolTable*, Vec<ASR::stmt_t*>> symtab2decls; |
| 137 | + |
| 138 | + public: |
| 139 | + |
| 140 | + StructTypeConstructorVisitor(Allocator& al_) : |
| 141 | + al(al_), remove_original_statement(false), |
| 142 | + inside_symtab(true), replacer(al_, pass_result, |
| 143 | + remove_original_statement, inside_symtab, |
| 144 | + symtab2decls) { |
| 145 | + pass_result.n = 0; |
| 146 | + pass_result.reserve(al, 0); |
67 | 147 | }
|
68 |
| - ASR::Struct_t* dt_der = down_cast<ASR::Struct_t>(x.m_type); |
69 |
| - ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>( |
70 |
| - ASRUtils::symbol_get_past_external(dt_der->m_derived_type)); |
71 |
| - for( size_t i = 0; i < std::min(dt_dertype->n_members, x.n_args); i++ ) { |
72 |
| - ASR::symbol_t* member = dt_dertype->m_symtab->resolve_symbol(std::string(dt_dertype->m_members[i], strlen(dt_dertype->m_members[i]))); |
73 |
| - ASR::symbol_t *v = nullptr; |
74 |
| - if (is_a<ASR::Var_t>(*result_var)) { |
75 |
| - v = down_cast<ASR::Var_t>(result_var)->m_v; |
| 148 | + |
| 149 | + void call_replacer() { |
| 150 | + replacer.current_expr = current_expr; |
| 151 | + replacer.current_scope = current_scope; |
| 152 | + replacer.replace_expr(*current_expr); |
| 153 | + } |
| 154 | + |
| 155 | + void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) { |
| 156 | + bool inside_symtab_copy = inside_symtab; |
| 157 | + inside_symtab = false; |
| 158 | + Vec<ASR::stmt_t*> body; |
| 159 | + body.reserve(al, n_body); |
| 160 | + |
| 161 | + if( symtab2decls.find(current_scope) != symtab2decls.end() ) { |
| 162 | + Vec<ASR::stmt_t*>& decls = symtab2decls[current_scope]; |
| 163 | + for (size_t j = 0; j < decls.size(); j++) { |
| 164 | + body.push_back(al, decls[j]); |
| 165 | + } |
| 166 | + symtab2decls.erase(current_scope); |
| 167 | + } |
| 168 | + |
| 169 | + for (size_t i = 0; i < n_body; i++) { |
| 170 | + pass_result.n = 0; |
| 171 | + pass_result.reserve(al, 1); |
| 172 | + remove_original_statement = false; |
| 173 | + replacer.result_var = nullptr; |
| 174 | + visit_stmt(*m_body[i]); |
| 175 | + for (size_t j = 0; j < pass_result.size(); j++) { |
| 176 | + body.push_back(al, pass_result[j]); |
| 177 | + } |
| 178 | + if( !remove_original_statement ) { |
| 179 | + body.push_back(al, m_body[i]); |
| 180 | + } |
| 181 | + remove_original_statement = false; |
76 | 182 | }
|
77 |
| - ASR::expr_t* derived_ref = ASRUtils::EXPR(ASRUtils::getStructInstanceMember_t(al, x.base.base.loc, (ASR::asr_t*)result_var, v, member, current_scope)); |
78 |
| - ASR::stmt_t* assign = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, derived_ref, x.m_args[i], nullptr)); |
79 |
| - pass_result.push_back(al, assign); |
| 183 | + m_body = body.p; |
| 184 | + n_body = body.size(); |
| 185 | + replacer.result_var = nullptr; |
| 186 | + pass_result.n = 0; |
| 187 | + pass_result.reserve(al, 0); |
| 188 | + inside_symtab = inside_symtab_copy; |
80 | 189 | }
|
81 |
| - } |
| 190 | + |
| 191 | + void visit_Variable(const ASR::Variable_t &x) { |
| 192 | + ASR::Variable_t& xx = const_cast<ASR::Variable_t&>(x); |
| 193 | + replacer.result_var = ASRUtils::EXPR(ASR::make_Var_t(al, |
| 194 | + x.base.base.loc, &(xx.base))); |
| 195 | + ASR::CallReplacerOnExpressionsVisitor< |
| 196 | + StructTypeConstructorVisitor>::visit_Variable(x); |
| 197 | + } |
| 198 | + |
| 199 | + void visit_Assignment(const ASR::Assignment_t &x) { |
| 200 | + if (x.m_overloaded) { |
| 201 | + this->visit_stmt(*x.m_overloaded); |
| 202 | + remove_original_statement = false; |
| 203 | + return ; |
| 204 | + } |
| 205 | + |
| 206 | + replacer.result_var = x.m_target; |
| 207 | + ASR::expr_t** current_expr_copy_9 = current_expr; |
| 208 | + current_expr = const_cast<ASR::expr_t**>(&(x.m_value)); |
| 209 | + this->call_replacer(); |
| 210 | + current_expr = current_expr_copy_9; |
| 211 | + if( !remove_original_statement ) { |
| 212 | + this->visit_expr(*x.m_value); |
| 213 | + } |
| 214 | + } |
| 215 | + |
82 | 216 | };
|
83 | 217 |
|
84 |
| -void pass_replace_class_constructor(Allocator &al, ASR::TranslationUnit_t &unit, |
85 |
| - const LCompilers::PassOptions& /*pass_options*/) { |
86 |
| - ClassConstructorVisitor v(al); |
87 |
| - do { |
88 |
| - v.is_constructor_present = false; |
89 |
| - v.visit_TranslationUnit(unit); |
90 |
| - } while( v.is_constructor_present ); |
| 218 | +void pass_replace_class_constructor(Allocator &al, |
| 219 | + ASR::TranslationUnit_t &unit, |
| 220 | + const LCompilers::PassOptions& /*pass_options*/) { |
| 221 | + StructTypeConstructorVisitor v(al); |
| 222 | + v.visit_TranslationUnit(unit); |
| 223 | + PassUtils::UpdateDependenciesVisitor w(al); |
| 224 | + w.visit_TranslationUnit(unit); |
91 | 225 | }
|
92 | 226 |
|
93 | 227 |
|
|
0 commit comments