Skip to content
Closed
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ set(SRC
pass/unused_functions.cpp
pass/flip_sign.cpp
pass/div_to_mul.cpp
pass/replace_symbolic.cpp
pass/intrinsic_function.cpp
pass/fma.cpp
pass/loop_vectorise.cpp
Expand Down
1 change: 1 addition & 0 deletions src/libasr/gen_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"replace_implied_do_loops",
"replace_init_expr",
"inline_function_calls",
"replace_symbolic",
"replace_intrinsic_function",
"loop_unroll",
"loop_vectorise",
Expand Down
3 changes: 3 additions & 0 deletions src/libasr/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <libasr/pass/replace_arr_slice.h>
#include <libasr/pass/replace_flip_sign.h>
#include <libasr/pass/replace_div_to_mul.h>
#include <libasr/pass/replace_symbolic.h>
#include <libasr/pass/replace_intrinsic_function.h>
#include <libasr/pass/replace_fma.h>
#include <libasr/pass/loop_unroll.h>
Expand Down Expand Up @@ -71,6 +72,7 @@ namespace LCompilers {
{"global_stmts", &pass_wrap_global_stmts},
{"implied_do_loops", &pass_replace_implied_do_loops},
{"array_op", &pass_replace_array_op},
{"symbolic", &pass_replace_symbolic},
{"intrinsic_function", &pass_replace_intrinsic_function},
{"arr_slice", &pass_replace_arr_slice},
{"print_arr", &pass_replace_print_arr},
Expand Down Expand Up @@ -203,6 +205,7 @@ namespace LCompilers {
"subroutine_from_function",
"where",
"array_op",
"symbolic",
"intrinsic_function",
"array_op",
"pass_array_by_data",
Expand Down
175 changes: 175 additions & 0 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#include <libasr/asr.h>
#include <libasr/containers.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>
#include <libasr/asr_verify.h>
#include <libasr/pass/replace_symbolic.h>
#include <libasr/pass/pass_utils.h>

#include <vector>


namespace LCompilers {

using ASR::down_cast;
using ASR::is_a;


class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisitor>
{
public:
ReplaceSymbolicVisitor(Allocator &al_) :
PassVisitor(al_, nullptr) {
Comment on lines +21 to +22
Copy link
Collaborator

@Thirumalai-Shaktivel Thirumalai-Shaktivel Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this: #2200 (comment) doesn't work, pass m_global_scope as an argument to ReplaceSymbolicVisitor and pass that symtab to PassVisitor instead of nullptr.

pass_result.reserve(al, 1);
}

bool symbolic_replaces_with_CPtr_Module = false;
bool symbolic_replaces_with_CPtr_Function = false;

void visit_Module(const ASR::Module_t &x) {
SymbolTable* current_scope_copy = current_scope;
current_scope = x.m_symtab;
for (auto &a : x.m_symtab->get_scope()) {
this->visit_symbol(*a.second);
if(symbolic_replaces_with_CPtr_Module){
std::string new_name = current_scope->get_unique_name("basic_new_stack");
if(current_scope->get_symbol(new_name)) return;
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(current_scope);

Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::ttype_t *arg_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, x.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, arg_type, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)));
}
Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char *> dep;
dep.reserve(al, 1);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, x.base.base.loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false);
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
current_scope->add_symbol(new_name, new_symbol);
symbolic_replaces_with_CPtr_Module = false;
}
}
current_scope = current_scope_copy;
}

void visit_Function(const ASR::Function_t& x) {
ASR::Function_t &xx = const_cast<ASR::Function_t&>(x);
SymbolTable* current_scope_copy = current_scope;
current_scope = xx.m_symtab;
for (auto &item : current_scope->get_scope()) {
if (is_a<ASR::Variable_t>(*item.second)) {
this->visit_symbol(*item.second);
if(symbolic_replaces_with_CPtr_Function){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this operation in the visit_variable itself.
And remove both visit_Module and visit_Function.
I think PassVisitor handles everything, Even transform_stmts.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can access the Function scope using current_scope and the Module scope using current_scope->parent.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Store all the statements in pass_result.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok so if we get rid of visit_Module and visit_Function ,
the scope for function is x.m_parent_symtab and that of module is x.m_parent_symtab->parent right ?

Copy link
Collaborator

@Thirumalai-Shaktivel Thirumalai-Shaktivel Jul 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, that would work, but let's use current_scope instead.
PassVisitor would have assigned current_scope with x.m_symtab in visit_Function (in asr.h: ASRPassBaseWalkVisitor).
So you can use current_scope to access the Function scope and current_scope-> parent to access the Module scope.

std::string var = std::string(item.first);
std::string placeholder = "_" + var;
ASR::symbol_t* var_sym = current_scope->get_symbol(var);
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym));

// statement 1
int cast_kind = ASR::cast_kindType::IntegerToInteger;
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4));
ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8));
ASR::expr_t* cast_tar = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type1));
ASR::expr_t* cast_val = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2));
ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc, cast_tar, (ASR::cast_kindType)cast_kind, type2, cast_val));

// statement 2
ASR::ttype_t *type3 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type3));

// statement 3
ASR::ttype_t *type4 = ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2));
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc, target1, type4, nullptr));
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node, type3, nullptr));

// statement 4
// ASR::symbol_t* basic_new_stack_sym = current_scope->parent->get_symbol("basic_new_stack");
// Vec<ASR::call_arg_t> call_args;
// call_args.reserve(al, 1);
// ASR::call_arg_t call_arg;
// call_arg.loc = xx.base.base.loc;
// call_arg.m_value = target2;
// call_args.push_back(al, call_arg);

// defining the assignment statement
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
//ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym, nullptr, call_args.p, call_args.n, nullptr));

// push stmt1 into the updated body vector
pass_result.push_back(al, stmt1);
pass_result.push_back(al, stmt2);
pass_result.push_back(al, stmt3);
//pass_result.push_back(al, stmt4);

// updated x.m_body and x.n_bdoy with that of the updated vector
// x.m_body = updated_body.p;
// x.n_body = updated_body.size();
transform_stmts(xx.m_body, xx.n_body);
symbolic_replaces_with_CPtr_Function = false;
}
}
}
current_scope = current_scope_copy;
}

void visit_Variable(const ASR::Variable_t& x) {
if (x.m_type->type == ASR::ttypeType::SymbolicExpression) {
symbolic_replaces_with_CPtr_Module = true;
symbolic_replaces_with_CPtr_Function = true;
std::string var_name = x.m_name;
std::string placeholder = "_" + std::string(x.m_name);

// defining CPtr variable
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::symbol_t* sym1 = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, x.base.base.loc, current_scope,
s2c(al, var_name), nullptr, 0,
x.m_intent, nullptr,
nullptr, x.m_storage,
type1, nullptr, x.m_abi,
x.m_access, x.m_presence,
x.m_value_attr));

ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::symbol_t* sym2 = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, x.base.base.loc, current_scope,
s2c(al, placeholder), nullptr, 0,
x.m_intent, nullptr,
nullptr, x.m_storage,
type2, nullptr, x.m_abi,
x.m_access, x.m_presence,
x.m_value_attr));

current_scope->erase_symbol(s2c(al, var_name));
current_scope->add_symbol(s2c(al, var_name), sym1);
current_scope->add_symbol(s2c(al, placeholder), sym2);
}
}
};

void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& /*pass_options*/) {
ReplaceSymbolicVisitor v(al);
v.visit_TranslationUnit(unit);
}

} // namespace LCompilers
14 changes: 14 additions & 0 deletions src/libasr/pass/replace_symbolic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef LIBASR_PASS_REPLACE_SYMBOLIC_H
#define LIBASR_PASS_REPLACE_SYMBOLIC_H

#include <libasr/asr.h>
#include <libasr/utils.h>

namespace LCompilers {

void pass_replace_symbolic(Allocator &al, ASR::TranslationUnit_t &unit,
const PassOptions &pass_options);

} // namespace LCompilers

#endif // LIBASR_PASS_REPLACE_SYMBOLIC_H