From d88181460ac041fcf61ac8134454be1b007ec028 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 17 Feb 2023 10:17:24 +0530 Subject: [PATCH] Revamp arr_slice.cpp --- src/libasr/asdl_cpp.py | 31 +++- src/libasr/pass/arr_slice.cpp | 312 +++++++++++++--------------------- 2 files changed, 150 insertions(+), 193 deletions(-) diff --git a/src/libasr/asdl_cpp.py b/src/libasr/asdl_cpp.py index 801eec0686..05d93a7e1a 100644 --- a/src/libasr/asdl_cpp.py +++ b/src/libasr/asdl_cpp.py @@ -422,8 +422,14 @@ def visitModule(self, mod): self.emit(" Struct& self() { return static_cast(*this); }") self.emit("public:") self.emit(" ASR::expr_t** current_expr;") + self.emit(" SymbolTable* current_scope;") self.emit("") self.emit(" void call_replacer() {}") + self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {") + self.emit(" for (size_t i = 0; i < n_body; i++) {", 1) + self.emit(" self().visit_stmt(*m_body[i]);", 1) + self.emit(" }", 1) + self.emit(" }") super(CallReplacerOnExpressionsVisitor, self).visitModule(mod) self.emit("};") @@ -440,14 +446,34 @@ def visitConstructor(self, cons, _): def make_visitor(self, name, fields): self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1) + is_symtab_present = False + is_stmt_present = False + symtab_field_name = "" + for field in fields: + if field.type == "stmt": + is_stmt_present = True + if field.type == "symbol_table": + is_symtab_present = True + symtab_field_name = field.name + if is_stmt_present and is_symtab_present: + break + if is_stmt_present and name not in ("Assignment", "ForAllSingle"): + self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1) self.used = False - have_body = False + + if is_symtab_present: + self.emit("SymbolTable* current_scope_copy = current_scope;", 2) + self.emit("current_scope = x.m_%s;" % symtab_field_name, 2) + for field in fields: self.visitField(field) if not self.used: # Note: a better solution would be to change `&x` to `& /* x */` # above, but we would need to change emit to return a string. self.emit("if ((bool&)x) { } // Suppress unused warning", 2) + + if is_symtab_present: + self.emit("current_scope = current_scope_copy;", 2) self.emit("}", 1) def insert_call_replacer_code(self, name, level, index=""): @@ -462,6 +488,9 @@ def visitField(self, field): field.type not in self.data.simple_types): level = 2 if field.seq: + if field.type == "stmt": + self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level) + return self.used = True self.emit("for (size_t i=0; i #include #include -#include #include -#include -#include +#include + +#include #include namespace LCompilers { -using ASR::down_cast; -using ASR::is_a; - /* This ASR pass replaces array slice with do loops and array expression assignments. The function `pass_replace_arr_slice` transforms the ASR tree in-place. @@ -30,257 +27,188 @@ The function `pass_replace_arr_slice` transforms the ASR tree in-place. end do */ -class ArrSliceVisitor : public PassUtils::PassVisitor -{ -private: +class ReplaceArraySection: public ASR::BaseExprReplacer { - ASR::expr_t* slice_var; - bool create_slice_var; + private: - int slice_counter; + Allocator& al; + Vec& pass_result; + size_t slice_counter; - std::string rl_path; + public: -public: - ArrSliceVisitor(Allocator &al, const std::string &rl_path) : PassVisitor(al, nullptr), - slice_var(nullptr), create_slice_var(false), slice_counter(0), - rl_path(rl_path) - { - pass_result.reserve(al, 1); - } + SymbolTable* current_scope; + + ReplaceArraySection(Allocator& al_, Vec& pass_result_) : + al(al_), pass_result(pass_result_), slice_counter(0), current_scope(nullptr) {} - ASR::ttype_t* get_array_from_slice(const ASR::ArraySection_t& x, ASR::expr_t* arr_var) { + ASR::ttype_t* get_array_from_slice(ASR::ArraySection_t* x, ASR::expr_t* arr_var) { Vec m_dims; - m_dims.reserve(al, x.n_args); - ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4, nullptr, 0)); - ASR::expr_t* const_1 = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)); - for( size_t i = 0; i < x.n_args; i++ ) { - if( x.m_args[i].m_step != nullptr ) { + m_dims.reserve(al, x->n_args); + ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x->base.base.loc, 4, nullptr, 0)); + ASR::expr_t* const_1 = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x->base.base.loc, 1, int32_type)); + for( size_t i = 0; i < x->n_args; i++ ) { + if( x->m_args[i].m_step != nullptr ) { ASR::expr_t *start = nullptr, *end = nullptr, *step = nullptr; - if( x.m_args[i].m_left == nullptr ) { + if( x->m_args[i].m_left == nullptr ) { start = PassUtils::get_bound(arr_var, i + 1, "lbound", al); } else { - start = x.m_args[i].m_left; + start = x->m_args[i].m_left; } - if( x.m_args[i].m_right == nullptr ) { + if( x->m_args[i].m_right == nullptr ) { end = PassUtils::get_bound(arr_var, i + 1, "ubound", al); } else { - end = x.m_args[i].m_right; + end = x->m_args[i].m_right; } - if( x.m_args[i].m_step == nullptr ) { + if( x->m_args[i].m_step == nullptr ) { step = const_1; } else { - step = x.m_args[i].m_step; + step = x->m_args[i].m_step; } start = PassUtils::to_int32(start, int32_type, al); end = PassUtils::to_int32(end, int32_type, al); step = PassUtils::to_int32(step, int32_type, al); - ASR::expr_t* gap = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, + ASR::expr_t* gap = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x->base.base.loc, end, ASR::binopType::Sub, start, int32_type, nullptr)); - ASR::expr_t* slice_size = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, + ASR::expr_t* slice_size = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x->base.base.loc, gap, ASR::binopType::Div, step, int32_type, nullptr)); - ASR::expr_t* actual_size = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, + ASR::expr_t* actual_size = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x->base.base.loc, slice_size, ASR::binopType::Add, const_1, int32_type, nullptr)); ASR::dimension_t curr_dim; - curr_dim.loc = x.base.base.loc; + curr_dim.loc = x->base.base.loc; curr_dim.m_start = const_1; curr_dim.m_length = actual_size; m_dims.push_back(al, curr_dim); } else { ASR::dimension_t curr_dim; - curr_dim.loc = x.base.base.loc; + curr_dim.loc = x->base.base.loc; curr_dim.m_start = const_1; curr_dim.m_length = const_1; m_dims.push_back(al, curr_dim); } } - ASR::ttype_t* new_type = nullptr; - ASR::ttype_t* t2 = ASRUtils::type_get_past_pointer(x.m_type); - switch (t2->type) - { - case ASR::ttypeType::Integer: { - ASR::Integer_t* curr_type = down_cast(t2); - new_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, curr_type->m_kind, - m_dims.p, m_dims.size())); - break; - } - case ASR::ttypeType::Real: { - ASR::Real_t* curr_type = down_cast(t2); - new_type = ASRUtils::TYPE(ASR::make_Real_t(al, x.base.base.loc, curr_type->m_kind, - m_dims.p, m_dims.size())); - break; - } - case ASR::ttypeType::Complex: { - ASR::Complex_t* curr_type = down_cast(t2); - new_type = ASRUtils::TYPE(ASR::make_Complex_t(al, x.base.base.loc, curr_type->m_kind, - m_dims.p, m_dims.size())); - break; - } - case ASR::ttypeType::Logical: { - ASR::Logical_t* curr_type = down_cast(t2); - new_type = ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, curr_type->m_kind, - m_dims.p, m_dims.size())); - break; - } - case ASR::ttypeType::Struct: { - ASR::Struct_t* curr_type = down_cast(t2); - new_type = ASRUtils::TYPE(ASR::make_Struct_t(al, x.base.base.loc, curr_type->m_derived_type, - m_dims.p, m_dims.size())); - break; - } - default: - break; - } - if (ASR::is_a(*x.m_type)) { - new_type = ASRUtils::TYPE(ASR::make_Pointer_t(al, x.base.base.loc, new_type)); + ASR::ttype_t* t2 = ASRUtils::type_get_past_pointer(x->m_type); + ASR::ttype_t* new_type = ASRUtils::duplicate_type(al, t2, &m_dims); + if (ASR::is_a(*x->m_type)) { + new_type = ASRUtils::TYPE(ASR::make_Pointer_t(al, x->base.base.loc, new_type)); } return new_type; } - void visit_ArraySection(const ASR::ArraySection_t& x) { - if( create_slice_var ) { - ASR::expr_t* x_arr_var = x.m_v; - Str new_name_str; - new_name_str.from_str(al, "~" + std::to_string(slice_counter) + "_slice"); - slice_counter += 1; - char* new_var_name = (char*)new_name_str.c_str(al); - ASR::asr_t* slice_asr = ASR::make_Variable_t(al, x.base.base.loc, current_scope, new_var_name, nullptr, 0, - ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default, - get_array_from_slice(x, x_arr_var), ASR::abiType::Source, ASR::accessType::Public, - ASR::presenceType::Required, false); - ASR::symbol_t* slice_sym = ASR::down_cast(slice_asr); - current_scope->add_symbol(std::string(new_var_name), slice_sym); - slice_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, slice_sym)); - Vec idx_vars_target, idx_vars_value; - PassUtils::create_idx_vars(idx_vars_target, x.n_args, x.base.base.loc, al, current_scope, "_t"); - PassUtils::create_idx_vars(idx_vars_value, x.n_args, x.base.base.loc, al, current_scope, "_v"); - ASR::stmt_t* doloop = nullptr; - int a_kind = ASRUtils::extract_kind_from_ttype_t(ASRUtils::expr_type(x.m_v)); - ASR::ttype_t* int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, a_kind, nullptr, 0)); - ASR::expr_t* const_1 = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int_type)); - for( int i = (int)x.n_args - 1; i >= 0; i-- ) { - ASR::do_loop_head_t head; - head.m_v = idx_vars_value[i]; - if( x.m_args[i].m_step != nullptr ) { - if( x.m_args[i].m_left == nullptr ) { - head.m_start = PassUtils::get_bound(x_arr_var, i + 1, "lbound", al); - } else { - head.m_start = x.m_args[i].m_left; - } - if( x.m_args[i].m_right == nullptr ) { - head.m_end = PassUtils::get_bound(x_arr_var, i + 1, "ubound", al); - } else { - head.m_end = x.m_args[i].m_right; - } + void replace_ArraySection(ASR::ArraySection_t* x) { + LCOMPILERS_ASSERT(current_scope != nullptr); + ASR::expr_t* x_arr_var = x->m_v; + std::string new_name = "~" + std::to_string(slice_counter) + "_slice"; + slice_counter += 1; + char* new_var_name = s2c(al, new_name); + ASR::asr_t* slice_asr = ASR::make_Variable_t(al, x->base.base.loc, current_scope, new_var_name, nullptr, 0, + ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default, + get_array_from_slice(x, x_arr_var), ASR::abiType::Source, ASR::accessType::Public, + ASR::presenceType::Required, false); + ASR::symbol_t* slice_sym = ASR::down_cast(slice_asr); + current_scope->add_symbol(std::string(new_var_name), slice_sym); + *current_expr = ASRUtils::EXPR(ASR::make_Var_t(al, x->base.base.loc, slice_sym)); + Vec idx_vars_target, idx_vars_value; + PassUtils::create_idx_vars(idx_vars_target, x->n_args, x->base.base.loc, al, current_scope, "_t"); + PassUtils::create_idx_vars(idx_vars_value, x->n_args, x->base.base.loc, al, current_scope, "_v"); + ASR::stmt_t* doloop = nullptr; + int a_kind = ASRUtils::extract_kind_from_ttype_t(ASRUtils::expr_type(x->m_v)); + ASR::ttype_t* int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x->base.base.loc, a_kind, nullptr, 0)); + ASR::expr_t* const_1 = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x->base.base.loc, 1, int_type)); + for( int i = (int) x->n_args - 1; i >= 0; i-- ) { + ASR::do_loop_head_t head; + head.m_v = idx_vars_value[i]; + if( x->m_args[i].m_step != nullptr ) { + if( x->m_args[i].m_left == nullptr ) { + head.m_start = PassUtils::get_bound(x_arr_var, i + 1, "lbound", al); } else { - head.m_start = x.m_args[i].m_right; - head.m_end = x.m_args[i].m_right; + head.m_start = x->m_args[i].m_left; } - head.m_increment = x.m_args[i].m_step; - head.loc = head.m_v->base.loc; - Vec doloop_body; - doloop_body.reserve(al, 1); - if( doloop == nullptr ) { - ASR::expr_t* target_ref = PassUtils::create_array_ref(slice_sym, idx_vars_target, al, x.base.base.loc, x.m_type); - ASR::expr_t* value_ref = PassUtils::create_array_ref(x.m_v, idx_vars_value, al); - ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, target_ref, value_ref, nullptr)); - doloop_body.push_back(al, assign_stmt); + if( x->m_args[i].m_right == nullptr ) { + head.m_end = PassUtils::get_bound(x_arr_var, i + 1, "ubound", al); } else { - ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_target[i+1], const_1, nullptr)); - doloop_body.push_back(al, set_to_one); - doloop_body.push_back(al, doloop); + head.m_end = x->m_args[i].m_right; } - ASR::expr_t* inc_expr = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, idx_vars_target[i], ASR::binopType::Add, const_1, int_type, nullptr)); - ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_target[i], inc_expr, nullptr)); + } else { + head.m_start = x->m_args[i].m_right; + head.m_end = x->m_args[i].m_right; + } + head.m_increment = x->m_args[i].m_step; + head.loc = head.m_v->base.loc; + Vec doloop_body; + doloop_body.reserve(al, 1); + if( doloop == nullptr ) { + ASR::expr_t* target_ref = PassUtils::create_array_ref(slice_sym, idx_vars_target, al, x->base.base.loc, x->m_type); + ASR::expr_t* value_ref = PassUtils::create_array_ref(x->m_v, idx_vars_value, al); + ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x->base.base.loc, target_ref, value_ref, nullptr)); doloop_body.push_back(al, assign_stmt); - doloop = ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size())); + } else { + ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, x->base.base.loc, idx_vars_target[i+1], const_1, nullptr)); + doloop_body.push_back(al, set_to_one); + doloop_body.push_back(al, doloop); } - ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_target[0], const_1, nullptr)); - pass_result.push_back(al, set_to_one); - pass_result.push_back(al, doloop); + ASR::expr_t* inc_expr = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x->base.base.loc, idx_vars_target[i], ASR::binopType::Add, const_1, int_type, nullptr)); + ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x->base.base.loc, idx_vars_target[i], inc_expr, nullptr)); + doloop_body.push_back(al, assign_stmt); + doloop = ASRUtils::STMT(ASR::make_DoLoop_t(al, x->base.base.loc, head, doloop_body.p, doloop_body.size())); } + ASR::stmt_t* set_to_one = ASRUtils::STMT(ASR::make_Assignment_t(al, x->base.base.loc, idx_vars_target[0], const_1, nullptr)); + pass_result.push_back(al, set_to_one); + pass_result.push_back(al, doloop); } - void visit_Assignment(const ASR::Assignment_t& x) { - if( (ASR::is_a(*ASRUtils::expr_type(x.m_target)) && - ASR::is_a(*x.m_value)) || - ASR::is_a(*x.m_value) ) { - return ; - } - this->visit_expr(*x.m_value); - // If any slicing happened then do loop must have been created - // So, the current assignment should be inserted into pass_result - // so that it doesn't get ignored. - if( pass_result.size() > 0 ) { - pass_result.push_back(al, const_cast(&(x.base))); - } - } +}; - void visit_IntegerBinOp(const ASR::IntegerBinOp_t& x) { - handle_BinOp(x); - } - void visit_RealBinOp(const ASR::RealBinOp_t& x) { - handle_BinOp(x); - } - void visit_ComplexBinOp(const ASR::ComplexBinOp_t& x) { - handle_BinOp(x); - } - void visit_LogicalBinOp(const ASR::LogicalBinOp_t& x) { - handle_BinOp(x); - } +class ArraySectionVisitor : public ASR::CallReplacerOnExpressionsVisitor +{ + private: + + Allocator& al; + ReplaceArraySection replacer; + Vec pass_result; - template - void handle_BinOp(const T& x) { - T& xx = const_cast(x); - create_slice_var = true; - slice_var = nullptr; - this->visit_expr(*x.m_left); - if( slice_var != nullptr ) { - xx.m_left = slice_var; - slice_var = nullptr; + public: + + ArraySectionVisitor(Allocator& al_) : + al(al_), replacer(al_, pass_result) { + pass_result.reserve(al_, 1); } - this->visit_expr(*x.m_right); - if( slice_var != nullptr ) { - xx.m_right = slice_var; - slice_var = nullptr; + + void call_replacer() { + replacer.current_expr = current_expr; + replacer.replace_expr(*current_expr); + replacer.current_scope = current_scope; } - create_slice_var = false; - } - void visit_Print(const ASR::Print_t& x) { - ASR::Print_t& xx = const_cast(x); - for( size_t i = 0; i < xx.n_values; i++ ) { - slice_var = nullptr; - create_slice_var = true; - this->visit_expr(*xx.m_values[i]); - create_slice_var = false; - if( slice_var != nullptr ) { - xx.m_values[i] = slice_var; + void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) { + Vec body; + body.reserve(al, n_body); + for (size_t i=0; i 0 ) { - pass_result.push_back(al, const_cast(&(x.base))); - } - } + }; void pass_replace_arr_slice(Allocator &al, ASR::TranslationUnit_t &unit, - const LCompilers::PassOptions& pass_options) { - std::string rl_path = pass_options.runtime_library_dir; - ArrSliceVisitor v(al, rl_path); + const LCompilers::PassOptions& /*pass_options*/) { + ArraySectionVisitor v(al); v.visit_TranslationUnit(unit); PassUtils::UpdateDependenciesVisitor u(al); u.visit_TranslationUnit(unit); } - } // namespace LCompilers